1 //===-- TargetRewrite.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Target rewrite: rewriting of ops to make target-specific lowerings manifest.
10 // LLVM expects different lowering idioms to be used for distinct target
11 // triples. These distinctions are handled by this pass.
12 //
13 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
14 //
15 //===----------------------------------------------------------------------===//
16 
#include "PassDetail.h"
#include "Target.h"
#include "flang/Lower/Todo.h"
#include "flang/Optimizer/Builder/Character.h"
#include "flang/Optimizer/CodeGen/CodeGen.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/FIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <functional>
#include <utility>
31 
32 using namespace fir;
33 
34 #define DEBUG_TYPE "flang-target-rewrite"
35 
36 namespace {
37 
38 /// Fixups for updating a FuncOp's arguments and return values.
39 struct FixupTy {
40   enum class Codes {
41     ArgumentAsLoad,
42     ArgumentType,
43     CharPair,
44     ReturnAsStore,
45     ReturnType,
46     Split,
47     Trailing,
48     TrailingCharProc
49   };
50 
51   FixupTy(Codes code, std::size_t index, std::size_t second = 0)
52       : code{code}, index{index}, second{second} {}
53   FixupTy(Codes code, std::size_t index,
54           std::function<void(mlir::FuncOp)> &&finalizer)
55       : code{code}, index{index}, finalizer{finalizer} {}
56   FixupTy(Codes code, std::size_t index, std::size_t second,
57           std::function<void(mlir::FuncOp)> &&finalizer)
58       : code{code}, index{index}, second{second}, finalizer{finalizer} {}
59 
60   Codes code;
61   std::size_t index;
62   std::size_t second{};
63   llvm::Optional<std::function<void(mlir::FuncOp)>> finalizer{};
64 }; // namespace
65 
66 /// Target-specific rewriting of the FIR. This is a prerequisite pass to code
67 /// generation that traverses the FIR and modifies types and operations to a
68 /// form that is appropriate for the specific target. LLVM IR has specific
69 /// idioms that are used for distinct target processor and ABI combinations.
70 class TargetRewrite : public TargetRewriteBase<TargetRewrite> {
71 public:
72   TargetRewrite(const TargetRewriteOptions &options) {
73     noCharacterConversion = options.noCharacterConversion;
74     noComplexConversion = options.noComplexConversion;
75   }
76 
77   void runOnOperation() override final {
78     auto &context = getContext();
79     mlir::OpBuilder rewriter(&context);
80 
81     auto mod = getModule();
82     if (!forcedTargetTriple.empty())
83       setTargetTriple(mod, forcedTargetTriple);
84 
85     auto specifics = CodeGenSpecifics::get(getOperation().getContext(),
86                                            getTargetTriple(getOperation()),
87                                            getKindMapping(getOperation()));
88     setMembers(specifics.get(), &rewriter);
89 
90     // Perform type conversion on signatures and call sites.
91     if (mlir::failed(convertTypes(mod))) {
92       mlir::emitError(mlir::UnknownLoc::get(&context),
93                       "error in converting types to target abi");
94       signalPassFailure();
95     }
96 
97     // Convert ops in target-specific patterns.
98     mod.walk([&](mlir::Operation *op) {
99       if (auto call = dyn_cast<fir::CallOp>(op)) {
100         if (!hasPortableSignature(call.getFunctionType()))
101           convertCallOp(call);
102       } else if (auto dispatch = dyn_cast<DispatchOp>(op)) {
103         if (!hasPortableSignature(dispatch.getFunctionType()))
104           convertCallOp(dispatch);
105       } else if (auto addr = dyn_cast<AddrOfOp>(op)) {
106         if (addr.getType().isa<mlir::FunctionType>() &&
107             !hasPortableSignature(addr.getType()))
108           convertAddrOp(addr);
109       }
110     });
111 
112     clearMembers();
113   }
114 
115   mlir::ModuleOp getModule() { return getOperation(); }
116 
117   template <typename A, typename B, typename C>
118   std::function<mlir::Value(mlir::Operation *)>
119   rewriteCallComplexResultType(A ty, B &newResTys, B &newInTys, C &newOpers) {
120     auto m = specifics->complexReturnType(ty.getElementType());
121     // Currently targets mandate COMPLEX is a single aggregate or packed
122     // scalar, including the sret case.
123     assert(m.size() == 1 && "target lowering of complex return not supported");
124     auto resTy = std::get<mlir::Type>(m[0]);
125     auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]);
126     auto loc = mlir::UnknownLoc::get(resTy.getContext());
127     if (attr.isSRet()) {
128       assert(isa_ref_type(resTy));
129       mlir::Value stack =
130           rewriter->create<fir::AllocaOp>(loc, dyn_cast_ptrEleTy(resTy));
131       newInTys.push_back(resTy);
132       newOpers.push_back(stack);
133       return [=](mlir::Operation *) -> mlir::Value {
134         auto memTy = ReferenceType::get(ty);
135         auto cast = rewriter->create<ConvertOp>(loc, memTy, stack);
136         return rewriter->create<fir::LoadOp>(loc, cast);
137       };
138     }
139     newResTys.push_back(resTy);
140     return [=](mlir::Operation *call) -> mlir::Value {
141       auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
142       rewriter->create<fir::StoreOp>(loc, call->getResult(0), mem);
143       auto memTy = ReferenceType::get(ty);
144       auto cast = rewriter->create<ConvertOp>(loc, memTy, mem);
145       return rewriter->create<fir::LoadOp>(loc, cast);
146     };
147   }
148 
149   template <typename A, typename B, typename C>
150   void rewriteCallComplexInputType(A ty, mlir::Value oper, B &newInTys,
151                                    C &newOpers) {
152     auto m = specifics->complexArgumentType(ty.getElementType());
153     auto *ctx = ty.getContext();
154     auto loc = mlir::UnknownLoc::get(ctx);
155     if (m.size() == 1) {
156       // COMPLEX is a single aggregate
157       auto resTy = std::get<mlir::Type>(m[0]);
158       auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]);
159       auto oldRefTy = ReferenceType::get(ty);
160       if (attr.isByVal()) {
161         auto mem = rewriter->create<fir::AllocaOp>(loc, ty);
162         rewriter->create<fir::StoreOp>(loc, oper, mem);
163         newOpers.push_back(rewriter->create<ConvertOp>(loc, resTy, mem));
164       } else {
165         auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
166         auto cast = rewriter->create<ConvertOp>(loc, oldRefTy, mem);
167         rewriter->create<fir::StoreOp>(loc, oper, cast);
168         newOpers.push_back(rewriter->create<fir::LoadOp>(loc, mem));
169       }
170       newInTys.push_back(resTy);
171     } else {
172       assert(m.size() == 2);
173       // COMPLEX is split into 2 separate arguments
174       for (auto e : llvm::enumerate(m)) {
175         auto &tup = e.value();
176         auto ty = std::get<mlir::Type>(tup);
177         auto index = e.index();
178         auto idx = rewriter->getIntegerAttr(rewriter->getIndexType(), index);
179         auto val = rewriter->create<ExtractValueOp>(
180             loc, ty, oper, rewriter->getArrayAttr(idx));
181         newInTys.push_back(ty);
182         newOpers.push_back(val);
183       }
184     }
185   }
186 
187   // Convert fir.call and fir.dispatch Ops.
188   template <typename A>
189   void convertCallOp(A callOp) {
190     auto fnTy = callOp.getFunctionType();
191     auto loc = callOp.getLoc();
192     rewriter->setInsertionPoint(callOp);
193     llvm::SmallVector<mlir::Type> newResTys;
194     llvm::SmallVector<mlir::Type> newInTys;
195     llvm::SmallVector<mlir::Value> newOpers;
196 
197     // If the call is indirect, the first argument must still be the function
198     // to call.
199     int dropFront = 0;
200     if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
201       if (!callOp.callee().hasValue()) {
202         newInTys.push_back(fnTy.getInput(0));
203         newOpers.push_back(callOp.getOperand(0));
204         dropFront = 1;
205       }
206     }
207 
208     // Determine the rewrite function, `wrap`, for the result value.
209     llvm::Optional<std::function<mlir::Value(mlir::Operation *)>> wrap;
210     if (fnTy.getResults().size() == 1) {
211       mlir::Type ty = fnTy.getResult(0);
212       llvm::TypeSwitch<mlir::Type>(ty)
213           .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
214             wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys,
215                                                 newOpers);
216           })
217           .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
218             wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys,
219                                                 newOpers);
220           })
221           .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
222     } else if (fnTy.getResults().size() > 1) {
223       TODO(loc, "multiple results not supported yet");
224     }
225 
226     llvm::SmallVector<mlir::Type> trailingInTys;
227     llvm::SmallVector<mlir::Value> trailingOpers;
228     for (auto e : llvm::enumerate(
229              llvm::zip(fnTy.getInputs().drop_front(dropFront),
230                        callOp.getOperands().drop_front(dropFront)))) {
231       mlir::Type ty = std::get<0>(e.value());
232       mlir::Value oper = std::get<1>(e.value());
233       unsigned index = e.index();
234       llvm::TypeSwitch<mlir::Type>(ty)
235           .template Case<BoxCharType>([&](BoxCharType boxTy) {
236             bool sret;
237             if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
238               sret = callOp.callee() &&
239                      functionArgIsSRet(index,
240                                        getModule().lookupSymbol<mlir::FuncOp>(
241                                            *callOp.callee()));
242             } else {
243               // TODO: dispatch case; how do we put arguments on a call?
244               // We cannot put both an sret and the dispatch object first.
245               sret = false;
246               TODO(loc, "dispatch + sret not supported yet");
247             }
248             auto m = specifics->boxcharArgumentType(boxTy.getEleTy(), sret);
249             auto unbox =
250                 rewriter->create<UnboxCharOp>(loc, std::get<mlir::Type>(m[0]),
251                                               std::get<mlir::Type>(m[1]), oper);
252             // unboxed CHARACTER arguments
253             for (auto e : llvm::enumerate(m)) {
254               unsigned idx = e.index();
255               auto attr = std::get<CodeGenSpecifics::Attributes>(e.value());
256               auto argTy = std::get<mlir::Type>(e.value());
257               if (attr.isAppend()) {
258                 trailingInTys.push_back(argTy);
259                 trailingOpers.push_back(unbox.getResult(idx));
260               } else {
261                 newInTys.push_back(argTy);
262                 newOpers.push_back(unbox.getResult(idx));
263               }
264             }
265           })
266           .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
267             rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers);
268           })
269           .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
270             rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers);
271           })
272           .template Case<mlir::TupleType>([&](mlir::TupleType tuple) {
273             if (factory::isCharacterProcedureTuple(tuple)) {
274               mlir::ModuleOp module = getModule();
275               if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
276                 if (callOp.callee()) {
277                   llvm::StringRef charProcAttr =
278                       fir::getCharacterProcedureDummyAttrName();
279                   // The charProcAttr attribute is only used as a safety to
280                   // confirm that this is a dummy procedure and should be split.
281                   // It cannot be used to match because attributes are not
282                   // available in case of indirect calls.
283                   auto funcOp =
284                       module.lookupSymbol<mlir::FuncOp>(*callOp.callee());
285                   if (funcOp &&
286                       !funcOp.template getArgAttrOfType<mlir::UnitAttr>(
287                           index, charProcAttr))
288                     mlir::emitError(loc, "tuple argument will be split even "
289                                          "though it does not have the `" +
290                                              charProcAttr + "` attribute");
291                 }
292               }
293               mlir::Type funcPointerType = tuple.getType(0);
294               mlir::Type lenType = tuple.getType(1);
295               FirOpBuilder builder(*rewriter, getKindMapping(module));
296               auto [funcPointer, len] =
297                   factory::extractCharacterProcedureTuple(builder, loc, oper);
298               newInTys.push_back(funcPointerType);
299               newOpers.push_back(funcPointer);
300               trailingInTys.push_back(lenType);
301               trailingOpers.push_back(len);
302             } else {
303               newInTys.push_back(tuple);
304               newOpers.push_back(oper);
305             }
306           })
307           .Default([&](mlir::Type ty) {
308             newInTys.push_back(ty);
309             newOpers.push_back(oper);
310           });
311     }
312     newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end());
313     newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());
314     if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
315       fir::CallOp newCall;
316       if (callOp.callee().hasValue()) {
317         newCall = rewriter->create<A>(loc, callOp.callee().getValue(),
318                                       newResTys, newOpers);
319       } else {
320         // Force new type on the input operand.
321         newOpers[0].setType(mlir::FunctionType::get(
322             callOp.getContext(),
323             mlir::TypeRange{newInTys}.drop_front(dropFront), newResTys));
324         newCall = rewriter->create<A>(loc, newResTys, newOpers);
325       }
326       LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n');
327       if (wrap.hasValue())
328         replaceOp(callOp, (*wrap)(newCall.getOperation()));
329       else
330         replaceOp(callOp, newCall.getResults());
331     } else {
332       // A is fir::DispatchOp
333       TODO(loc, "dispatch not implemented");
334     }
335   }
336 
337   // Result type fixup for fir::ComplexType and mlir::ComplexType
338   template <typename A, typename B>
339   void lowerComplexSignatureRes(A cmplx, B &newResTys, B &newInTys) {
340     if (noComplexConversion) {
341       newResTys.push_back(cmplx);
342     } else {
343       for (auto &tup : specifics->complexReturnType(cmplx.getElementType())) {
344         auto argTy = std::get<mlir::Type>(tup);
345         if (std::get<CodeGenSpecifics::Attributes>(tup).isSRet())
346           newInTys.push_back(argTy);
347         else
348           newResTys.push_back(argTy);
349       }
350     }
351   }
352 
353   // Argument type fixup for fir::ComplexType and mlir::ComplexType
354   template <typename A, typename B>
355   void lowerComplexSignatureArg(A cmplx, B &newInTys) {
356     if (noComplexConversion)
357       newInTys.push_back(cmplx);
358     else
359       for (auto &tup : specifics->complexArgumentType(cmplx.getElementType()))
360         newInTys.push_back(std::get<mlir::Type>(tup));
361   }
362 
363   /// Taking the address of a function. Modify the signature as needed.
364   void convertAddrOp(AddrOfOp addrOp) {
365     rewriter->setInsertionPoint(addrOp);
366     auto addrTy = addrOp.getType().cast<mlir::FunctionType>();
367     llvm::SmallVector<mlir::Type> newResTys;
368     llvm::SmallVector<mlir::Type> newInTys;
369     for (mlir::Type ty : addrTy.getResults()) {
370       llvm::TypeSwitch<mlir::Type>(ty)
371           .Case<fir::ComplexType>([&](fir::ComplexType ty) {
372             lowerComplexSignatureRes(ty, newResTys, newInTys);
373           })
374           .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
375             lowerComplexSignatureRes(ty, newResTys, newInTys);
376           })
377           .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
378     }
379     llvm::SmallVector<mlir::Type> trailingInTys;
380     for (mlir::Type ty : addrTy.getInputs()) {
381       llvm::TypeSwitch<mlir::Type>(ty)
382           .Case<BoxCharType>([&](BoxCharType box) {
383             if (noCharacterConversion) {
384               newInTys.push_back(box);
385             } else {
386               for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) {
387                 auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
388                 auto argTy = std::get<mlir::Type>(tup);
389                 llvm::SmallVector<mlir::Type> &vec =
390                     attr.isAppend() ? trailingInTys : newInTys;
391                 vec.push_back(argTy);
392               }
393             }
394           })
395           .Case<fir::ComplexType>([&](fir::ComplexType ty) {
396             lowerComplexSignatureArg(ty, newInTys);
397           })
398           .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
399             lowerComplexSignatureArg(ty, newInTys);
400           })
401           .Case<mlir::TupleType>([&](mlir::TupleType tuple) {
402             if (factory::isCharacterProcedureTuple(tuple)) {
403               newInTys.push_back(tuple.getType(0));
404               trailingInTys.push_back(tuple.getType(1));
405             } else {
406               newInTys.push_back(ty);
407             }
408           })
409           .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
410     }
411     // append trailing input types
412     newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end());
413     // replace this op with a new one with the updated signature
414     auto newTy = rewriter->getFunctionType(newInTys, newResTys);
415     auto newOp =
416         rewriter->create<AddrOfOp>(addrOp.getLoc(), newTy, addrOp.symbol());
417     replaceOp(addrOp, newOp.getResult());
418   }
419 
420   /// Convert the type signatures on all the functions present in the module.
421   /// As the type signature is being changed, this must also update the
422   /// function itself to use any new arguments, etc.
423   mlir::LogicalResult convertTypes(mlir::ModuleOp mod) {
424     for (auto fn : mod.getOps<mlir::FuncOp>())
425       convertSignature(fn);
426     return mlir::success();
427   }
428 
429   /// If the signature does not need any special target-specific converions,
430   /// then it is considered portable for any target, and this function will
431   /// return `true`. Otherwise, the signature is not portable and `false` is
432   /// returned.
433   bool hasPortableSignature(mlir::Type signature) {
434     assert(signature.isa<mlir::FunctionType>());
435     auto func = signature.dyn_cast<mlir::FunctionType>();
436     for (auto ty : func.getResults())
437       if ((ty.isa<BoxCharType>() && !noCharacterConversion) ||
438           (isa_complex(ty) && !noComplexConversion)) {
439         LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
440         return false;
441       }
442     for (auto ty : func.getInputs())
443       if (((ty.isa<BoxCharType>() || factory::isCharacterProcedureTuple(ty)) &&
444            !noCharacterConversion) ||
445           (isa_complex(ty) && !noComplexConversion)) {
446         LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
447         return false;
448       }
449     return true;
450   }
451 
  /// Rewrite the signatures and body of the `FuncOp`s in the module for
  /// the immediately subsequent target code gen.
  ///
  /// The work is done in three steps:
  ///  1. Compute the new result and argument type lists, recording a fixup
  ///     for every position whose type or placement changes.
  ///  2. If the function has a body, apply the fixups in place to the entry
  ///     block arguments and to the return ops.
  ///  3. Install the new function type and run any fixup finalizers.
  /// Functions whose signature is already portable are left untouched.
  void convertSignature(mlir::FuncOp func) {
    auto funcTy = func.getType().cast<mlir::FunctionType>();
    if (hasPortableSignature(funcTy))
      return;
    llvm::SmallVector<mlir::Type> newResTys;
    llvm::SmallVector<mlir::Type> newInTys;
    llvm::SmallVector<FixupTy> fixups;

    // Convert return value(s)
    for (auto ty : funcTy.getResults())
      llvm::TypeSwitch<mlir::Type>(ty)
          .Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
            if (noComplexConversion)
              newResTys.push_back(cmplx);
            else
              doComplexReturn(func, cmplx, newResTys, newInTys, fixups);
          })
          .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
            if (noComplexConversion)
              newResTys.push_back(cmplx);
            else
              doComplexReturn(func, cmplx, newResTys, newInTys, fixups);
          })
          .Default([&](mlir::Type ty) { newResTys.push_back(ty); });

    // Convert arguments
    // Types collected here are appended after all other arguments at the end.
    llvm::SmallVector<mlir::Type> trailingTys;
    for (auto e : llvm::enumerate(funcTy.getInputs())) {
      auto ty = e.value();
      unsigned index = e.index();
      llvm::TypeSwitch<mlir::Type>(ty)
          .Case<BoxCharType>([&](BoxCharType boxTy) {
            if (noCharacterConversion) {
              newInTys.push_back(boxTy);
            } else {
              // Convert a CHARACTER argument type. This can involve separating
              // the pointer and the LEN into two arguments and moving the LEN
              // argument to the end of the arg list.
              bool sret = functionArgIsSRet(index, func);
              // NOTE: `e` and `index` below shadow the outer loop variables
              // and refer to the position within the lowered boxchar pieces.
              for (auto e : llvm::enumerate(specifics->boxcharArgumentType(
                       boxTy.getEleTy(), sret))) {
                auto &tup = e.value();
                auto index = e.index();
                auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
                auto argTy = std::get<mlir::Type>(tup);
                if (attr.isAppend()) {
                  trailingTys.push_back(argTy);
                } else {
                  if (sret) {
                    fixups.emplace_back(FixupTy::Codes::CharPair,
                                        newInTys.size(), index);
                  } else {
                    // `second` is the slot in trailingTys that the appended
                    // piece of this argument will occupy.
                    fixups.emplace_back(FixupTy::Codes::Trailing,
                                        newInTys.size(), trailingTys.size());
                  }
                  newInTys.push_back(argTy);
                }
              }
            }
          })
          .Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
            if (noComplexConversion)
              newInTys.push_back(cmplx);
            else
              doComplexArg(func, cmplx, newInTys, fixups);
          })
          .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
            if (noComplexConversion)
              newInTys.push_back(cmplx);
            else
              doComplexArg(func, cmplx, newInTys, fixups);
          })
          .Case<mlir::TupleType>([&](mlir::TupleType tuple) {
            if (factory::isCharacterProcedureTuple(tuple)) {
              // The first tuple element stays in place; the second one is
              // moved to the trailing arguments.
              fixups.emplace_back(FixupTy::Codes::TrailingCharProc,
                                  newInTys.size(), trailingTys.size());
              newInTys.push_back(tuple.getType(0));
              trailingTys.push_back(tuple.getType(1));
            } else {
              newInTys.push_back(ty);
            }
          })
          .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
    }

    if (!func.empty()) {
      // If the function has a body, then apply the fixups to the arguments and
      // return ops as required. These fixups are done in place.
      auto loc = func.getLoc();
      const auto fixupSize = fixups.size();
      const auto oldArgTys = func.getType().getInputs();
      // `offset` is subtracted from a fixup's (new) argument index to index
      // into oldArgTys. It is bumped whenever a fixup adds an in-place
      // argument that has no counterpart in the old signature.
      int offset = 0;
      for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) {
        const auto &fixup = fixups[i];
        // Most cases follow the same pattern: insert the new argument at
        // fixup.index (which shifts the old argument to fixup.index + 1),
        // rebuild a value of the old type at the start of the entry block,
        // redirect all uses of the old argument to it, then erase the old
        // argument.
        switch (fixup.code) {
        case FixupTy::Codes::ArgumentAsLoad: {
          // Argument was pass-by-value, but is now pass-by-reference and
          // possibly with a different element type.
          auto newArg = func.front().insertArgument(fixup.index,
                                                    newInTys[fixup.index], loc);
          rewriter->setInsertionPointToStart(&func.front());
          auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]);
          auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, newArg);
          auto load = rewriter->create<fir::LoadOp>(loc, cast);
          func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
          func.front().eraseArgument(fixup.index + 1);
        } break;
        case FixupTy::Codes::ArgumentType: {
          // Argument is pass-by-value, but its type has likely been modified to
          // suit the target ABI convention.
          auto newArg = func.front().insertArgument(fixup.index,
                                                    newInTys[fixup.index], loc);
          rewriter->setInsertionPointToStart(&func.front());
          // Spill the new-typed argument and reload it through a reference to
          // the old type.
          auto mem =
              rewriter->create<fir::AllocaOp>(loc, newInTys[fixup.index]);
          rewriter->create<fir::StoreOp>(loc, newArg, mem);
          auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]);
          auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, mem);
          mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast);
          func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
          func.front().eraseArgument(fixup.index + 1);
          LLVM_DEBUG(llvm::dbgs()
                     << "old argument: " << oldArgTy.getEleTy()
                     << ", repl: " << load << ", new argument: "
                     << func.getArgument(fixup.index).getType() << '\n');
        } break;
        case FixupTy::Codes::CharPair: {
          // The FIR boxchar argument has been split into a pair of distinct
          // arguments that are in juxtaposition to each other.
          auto newArg = func.front().insertArgument(fixup.index,
                                                    newInTys[fixup.index], loc);
          // Only the fixup for the second piece rebuilds the boxchar; the
          // first piece's fixup just inserts its argument.
          if (fixup.second == 1) {
            rewriter->setInsertionPointToStart(&func.front());
            auto boxTy = oldArgTys[fixup.index - offset - fixup.second];
            auto box = rewriter->create<EmboxCharOp>(
                loc, boxTy, func.front().getArgument(fixup.index - 1), newArg);
            func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
            func.front().eraseArgument(fixup.index + 1);
            offset++;
          }
        } break;
        case FixupTy::Codes::ReturnAsStore: {
          // The value being returned is now being returned in memory (callee
          // stack space) through a hidden reference argument.
          auto newArg = func.front().insertArgument(fixup.index,
                                                    newInTys[fixup.index], loc);
          offset++;
          // Replace every return with a store through the hidden argument
          // followed by a return without operands.
          func.walk([&](mlir::ReturnOp ret) {
            rewriter->setInsertionPoint(ret);
            auto oldOper = ret.getOperand(0);
            auto oldOperTy = ReferenceType::get(oldOper.getType());
            auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, newArg);
            rewriter->create<fir::StoreOp>(loc, oldOper, cast);
            rewriter->create<mlir::ReturnOp>(loc);
            ret.erase();
          });
        } break;
        case FixupTy::Codes::ReturnType: {
          // The function is still returning a value, but its type has likely
          // changed to suit the target ABI convention.
          func.walk([&](mlir::ReturnOp ret) {
            rewriter->setInsertionPoint(ret);
            auto oldOper = ret.getOperand(0);
            auto oldOperTy = ReferenceType::get(oldOper.getType());
            // Store the old-typed value through a temporary of the new type
            // and return the reloaded, new-typed value.
            auto mem =
                rewriter->create<fir::AllocaOp>(loc, newResTys[fixup.index]);
            auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, mem);
            rewriter->create<fir::StoreOp>(loc, oldOper, cast);
            mlir::Value load = rewriter->create<fir::LoadOp>(loc, mem);
            rewriter->create<mlir::ReturnOp>(loc, load);
            ret.erase();
          });
        } break;
        case FixupTy::Codes::Split: {
          // The FIR argument has been split into a pair of distinct arguments
          // that are in juxtaposition to each other. (For COMPLEX value.)
          auto newArg = func.front().insertArgument(fixup.index,
                                                    newInTys[fixup.index], loc);
          // Only the fixup for the second piece reassembles the value, from
          // the previous argument (element 0) and this one (element 1).
          if (fixup.second == 1) {
            rewriter->setInsertionPointToStart(&func.front());
            auto cplxTy = oldArgTys[fixup.index - offset - fixup.second];
            auto undef = rewriter->create<UndefOp>(loc, cplxTy);
            auto zero = rewriter->getIntegerAttr(rewriter->getIndexType(), 0);
            auto one = rewriter->getIntegerAttr(rewriter->getIndexType(), 1);
            auto cplx1 = rewriter->create<InsertValueOp>(
                loc, cplxTy, undef, func.front().getArgument(fixup.index - 1),
                rewriter->getArrayAttr(zero));
            auto cplx = rewriter->create<InsertValueOp>(
                loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr(one));
            func.getArgument(fixup.index + 1).replaceAllUsesWith(cplx);
            func.front().eraseArgument(fixup.index + 1);
            offset++;
          }
        } break;
        case FixupTy::Codes::Trailing: {
          // The FIR argument has been split into a pair of distinct arguments.
          // The first part of the pair appears in the original argument
          // position. The second part of the pair is appended after all the
          // original arguments. (Boxchar arguments.)
          auto newBufArg = func.front().insertArgument(
              fixup.index, newInTys[fixup.index], loc);
          auto newLenArg =
              func.front().addArgument(trailingTys[fixup.second], loc);
          auto boxTy = oldArgTys[fixup.index - offset];
          rewriter->setInsertionPointToStart(&func.front());
          auto box =
              rewriter->create<EmboxCharOp>(loc, boxTy, newBufArg, newLenArg);
          func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
          func.front().eraseArgument(fixup.index + 1);
        } break;
        case FixupTy::Codes::TrailingCharProc: {
          // The FIR character procedure argument tuple has been split into a
          // pair of distinct arguments. The first part of the pair appears in
          // the original argument position. The second part of the pair is
          // appended after all the original arguments.
          auto newProcPointerArg = func.front().insertArgument(
              fixup.index, newInTys[fixup.index], loc);
          auto newLenArg =
              func.front().addArgument(trailingTys[fixup.second], loc);
          auto tupleType = oldArgTys[fixup.index - offset];
          rewriter->setInsertionPointToStart(&func.front());
          FirOpBuilder builder(*rewriter, getKindMapping(getModule()));
          auto tuple = factory::createCharacterProcedureTuple(
              builder, loc, tupleType, newProcPointerArg, newLenArg);
          func.getArgument(fixup.index + 1).replaceAllUsesWith(tuple);
          func.front().eraseArgument(fixup.index + 1);
        } break;
        }
      }
    }

    // Set the new type and finalize the arguments, etc.
    // The trailing types are only merged into newInTys here, after the fixup
    // loop, which indexes newInTys and trailingTys separately.
    newInTys.insert(newInTys.end(), trailingTys.begin(), trailingTys.end());
    auto newFuncTy =
        mlir::FunctionType::get(func.getContext(), newInTys, newResTys);
    LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n');
    func.setType(newFuncTy);

    // Finalizers run last, once the function carries its new type.
    for (auto &fixup : fixups)
      if (fixup.finalizer)
        (*fixup.finalizer)(func);
  }
696 
697   inline bool functionArgIsSRet(unsigned index, mlir::FuncOp func) {
698     if (auto attr = func.getArgAttrOfType<mlir::UnitAttr>(index, "llvm.sret"))
699       return true;
700     return false;
701   }
702 
703   /// Convert a complex return value. This can involve converting the return
704   /// value to a "hidden" first argument or packing the complex into a wide
705   /// GPR.
706   template <typename A, typename B, typename C>
707   void doComplexReturn(mlir::FuncOp func, A cmplx, B &newResTys, B &newInTys,
708                        C &fixups) {
709     if (noComplexConversion) {
710       newResTys.push_back(cmplx);
711       return;
712     }
713     auto m = specifics->complexReturnType(cmplx.getElementType());
714     assert(m.size() == 1);
715     auto &tup = m[0];
716     auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
717     auto argTy = std::get<mlir::Type>(tup);
718     if (attr.isSRet()) {
719       unsigned argNo = newInTys.size();
720       fixups.emplace_back(
721           FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::FuncOp func) {
722             func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr());
723           });
724       newInTys.push_back(argTy);
725       return;
726     }
727     fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size());
728     newResTys.push_back(argTy);
729   }
730 
731   /// Convert a complex argument value. This can involve storing the value to
732   /// a temporary memory location or factoring the value into two distinct
733   /// arguments.
734   template <typename A, typename B, typename C>
735   void doComplexArg(mlir::FuncOp func, A cmplx, B &newInTys, C &fixups) {
736     if (noComplexConversion) {
737       newInTys.push_back(cmplx);
738       return;
739     }
740     auto m = specifics->complexArgumentType(cmplx.getElementType());
741     const auto fixupCode =
742         m.size() > 1 ? FixupTy::Codes::Split : FixupTy::Codes::ArgumentType;
743     for (auto e : llvm::enumerate(m)) {
744       auto &tup = e.value();
745       auto index = e.index();
746       auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
747       auto argTy = std::get<mlir::Type>(tup);
748       auto argNo = newInTys.size();
749       if (attr.isByVal()) {
750         if (auto align = attr.getAlignment())
751           fixups.emplace_back(
752               FixupTy::Codes::ArgumentAsLoad, argNo, [=](mlir::FuncOp func) {
753                 func.setArgAttr(argNo, "llvm.byval", rewriter->getUnitAttr());
754                 func.setArgAttr(argNo, "llvm.align",
755                                 rewriter->getIntegerAttr(
756                                     rewriter->getIntegerType(32), align));
757               });
758         else
759           fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, newInTys.size(),
760                               [=](mlir::FuncOp func) {
761                                 func.setArgAttr(argNo, "llvm.byval",
762                                                 rewriter->getUnitAttr());
763                               });
764       } else {
765         if (auto align = attr.getAlignment())
766           fixups.emplace_back(fixupCode, argNo, index, [=](mlir::FuncOp func) {
767             func.setArgAttr(
768                 argNo, "llvm.align",
769                 rewriter->getIntegerAttr(rewriter->getIntegerType(32), align));
770           });
771         else
772           fixups.emplace_back(fixupCode, argNo, index);
773       }
774       newInTys.push_back(argTy);
775     }
776   }
777 
778 private:
779   // Replace `op` and remove it.
780   void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) {
781     op->replaceAllUsesWith(newValues);
782     op->dropAllReferences();
783     op->erase();
784   }
785 
786   inline void setMembers(CodeGenSpecifics *s, mlir::OpBuilder *r) {
787     specifics = s;
788     rewriter = r;
789   }
790 
791   inline void clearMembers() { setMembers(nullptr, nullptr); }
792 
793   CodeGenSpecifics *specifics{};
794   mlir::OpBuilder *rewriter;
795 }; // namespace
796 } // namespace
797 
798 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
799 fir::createFirTargetRewritePass(const TargetRewriteOptions &options) {
800   return std::make_unique<TargetRewrite>(options);
801 }
802