1 //===-- TargetRewrite.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Target rewrite: rewriting of ops to make target-specific lowerings manifest.
10 // LLVM expects different lowering idioms to be used for distinct target
11 // triples. These distinctions are handled by this pass.
12 //
13 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
14 //
15 //===----------------------------------------------------------------------===//
16 
#include "PassDetail.h"
#include "Target.h"
#include "flang/Lower/Todo.h"
#include "flang/Optimizer/Builder/Character.h"
#include "flang/Optimizer/CodeGen/CodeGen.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/FIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <utility>
31 
32 using namespace fir;
33 
34 #define DEBUG_TYPE "flang-target-rewrite"
35 
36 namespace {
37 
38 /// Fixups for updating a FuncOp's arguments and return values.
39 struct FixupTy {
40   enum class Codes {
41     ArgumentAsLoad,
42     ArgumentType,
43     CharPair,
44     ReturnAsStore,
45     ReturnType,
46     Split,
47     Trailing,
48     TrailingCharProc
49   };
50 
51   FixupTy(Codes code, std::size_t index, std::size_t second = 0)
52       : code{code}, index{index}, second{second} {}
53   FixupTy(Codes code, std::size_t index,
54           std::function<void(mlir::FuncOp)> &&finalizer)
55       : code{code}, index{index}, finalizer{finalizer} {}
56   FixupTy(Codes code, std::size_t index, std::size_t second,
57           std::function<void(mlir::FuncOp)> &&finalizer)
58       : code{code}, index{index}, second{second}, finalizer{finalizer} {}
59 
60   Codes code;
61   std::size_t index;
62   std::size_t second{};
63   llvm::Optional<std::function<void(mlir::FuncOp)>> finalizer{};
64 }; // namespace
65 
66 /// Target-specific rewriting of the FIR. This is a prerequisite pass to code
67 /// generation that traverses the FIR and modifies types and operations to a
68 /// form that is appropriate for the specific target. LLVM IR has specific
69 /// idioms that are used for distinct target processor and ABI combinations.
70 class TargetRewrite : public TargetRewriteBase<TargetRewrite> {
71 public:
72   TargetRewrite(const TargetRewriteOptions &options) {
73     noCharacterConversion = options.noCharacterConversion;
74     noComplexConversion = options.noComplexConversion;
75   }
76 
77   void runOnOperation() override final {
78     auto &context = getContext();
79     mlir::OpBuilder rewriter(&context);
80 
81     auto mod = getModule();
82     if (!forcedTargetTriple.empty())
83       setTargetTriple(mod, forcedTargetTriple);
84 
85     auto specifics = CodeGenSpecifics::get(getOperation().getContext(),
86                                            getTargetTriple(getOperation()),
87                                            getKindMapping(getOperation()));
88     setMembers(specifics.get(), &rewriter);
89 
90     // Perform type conversion on signatures and call sites.
91     if (mlir::failed(convertTypes(mod))) {
92       mlir::emitError(mlir::UnknownLoc::get(&context),
93                       "error in converting types to target abi");
94       signalPassFailure();
95     }
96 
97     // Convert ops in target-specific patterns.
98     mod.walk([&](mlir::Operation *op) {
99       if (auto call = dyn_cast<fir::CallOp>(op)) {
100         if (!hasPortableSignature(call.getFunctionType()))
101           convertCallOp(call);
102       } else if (auto dispatch = dyn_cast<DispatchOp>(op)) {
103         if (!hasPortableSignature(dispatch.getFunctionType()))
104           convertCallOp(dispatch);
105       } else if (auto addr = dyn_cast<AddrOfOp>(op)) {
106         if (addr.getType().isa<mlir::FunctionType>() &&
107             !hasPortableSignature(addr.getType()))
108           convertAddrOp(addr);
109       }
110     });
111 
112     clearMembers();
113   }
114 
115   mlir::ModuleOp getModule() { return getOperation(); }
116 
117   template <typename A, typename B, typename C>
118   std::function<mlir::Value(mlir::Operation *)>
119   rewriteCallComplexResultType(A ty, B &newResTys, B &newInTys, C &newOpers) {
120     auto m = specifics->complexReturnType(ty.getElementType());
121     // Currently targets mandate COMPLEX is a single aggregate or packed
122     // scalar, including the sret case.
123     assert(m.size() == 1 && "target lowering of complex return not supported");
124     auto resTy = std::get<mlir::Type>(m[0]);
125     auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]);
126     auto loc = mlir::UnknownLoc::get(resTy.getContext());
127     if (attr.isSRet()) {
128       assert(isa_ref_type(resTy));
129       mlir::Value stack =
130           rewriter->create<fir::AllocaOp>(loc, dyn_cast_ptrEleTy(resTy));
131       newInTys.push_back(resTy);
132       newOpers.push_back(stack);
133       return [=](mlir::Operation *) -> mlir::Value {
134         auto memTy = ReferenceType::get(ty);
135         auto cast = rewriter->create<ConvertOp>(loc, memTy, stack);
136         return rewriter->create<fir::LoadOp>(loc, cast);
137       };
138     }
139     newResTys.push_back(resTy);
140     return [=](mlir::Operation *call) -> mlir::Value {
141       auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
142       rewriter->create<fir::StoreOp>(loc, call->getResult(0), mem);
143       auto memTy = ReferenceType::get(ty);
144       auto cast = rewriter->create<ConvertOp>(loc, memTy, mem);
145       return rewriter->create<fir::LoadOp>(loc, cast);
146     };
147   }
148 
149   template <typename A, typename B, typename C>
150   void rewriteCallComplexInputType(A ty, mlir::Value oper, B &newInTys,
151                                    C &newOpers) {
152     auto m = specifics->complexArgumentType(ty.getElementType());
153     auto *ctx = ty.getContext();
154     auto loc = mlir::UnknownLoc::get(ctx);
155     if (m.size() == 1) {
156       // COMPLEX is a single aggregate
157       auto resTy = std::get<mlir::Type>(m[0]);
158       auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]);
159       auto oldRefTy = ReferenceType::get(ty);
160       if (attr.isByVal()) {
161         auto mem = rewriter->create<fir::AllocaOp>(loc, ty);
162         rewriter->create<fir::StoreOp>(loc, oper, mem);
163         newOpers.push_back(rewriter->create<ConvertOp>(loc, resTy, mem));
164       } else {
165         auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
166         auto cast = rewriter->create<ConvertOp>(loc, oldRefTy, mem);
167         rewriter->create<fir::StoreOp>(loc, oper, cast);
168         newOpers.push_back(rewriter->create<fir::LoadOp>(loc, mem));
169       }
170       newInTys.push_back(resTy);
171     } else {
172       assert(m.size() == 2);
173       // COMPLEX is split into 2 separate arguments
174       auto iTy = rewriter->getIntegerType(32);
175       for (auto e : llvm::enumerate(m)) {
176         auto &tup = e.value();
177         auto ty = std::get<mlir::Type>(tup);
178         auto index = e.index();
179         auto idx = rewriter->getIntegerAttr(iTy, index);
180         auto val = rewriter->create<ExtractValueOp>(
181             loc, ty, oper, rewriter->getArrayAttr(idx));
182         newInTys.push_back(ty);
183         newOpers.push_back(val);
184       }
185     }
186   }
187 
188   // Convert fir.call and fir.dispatch Ops.
189   template <typename A>
190   void convertCallOp(A callOp) {
191     auto fnTy = callOp.getFunctionType();
192     auto loc = callOp.getLoc();
193     rewriter->setInsertionPoint(callOp);
194     llvm::SmallVector<mlir::Type> newResTys;
195     llvm::SmallVector<mlir::Type> newInTys;
196     llvm::SmallVector<mlir::Value> newOpers;
197 
198     // If the call is indirect, the first argument must still be the function
199     // to call.
200     int dropFront = 0;
201     if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
202       if (!callOp.callee().hasValue()) {
203         newInTys.push_back(fnTy.getInput(0));
204         newOpers.push_back(callOp.getOperand(0));
205         dropFront = 1;
206       }
207     }
208 
209     // Determine the rewrite function, `wrap`, for the result value.
210     llvm::Optional<std::function<mlir::Value(mlir::Operation *)>> wrap;
211     if (fnTy.getResults().size() == 1) {
212       mlir::Type ty = fnTy.getResult(0);
213       llvm::TypeSwitch<mlir::Type>(ty)
214           .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
215             wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys,
216                                                 newOpers);
217           })
218           .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
219             wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys,
220                                                 newOpers);
221           })
222           .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
223     } else if (fnTy.getResults().size() > 1) {
224       TODO(loc, "multiple results not supported yet");
225     }
226 
227     llvm::SmallVector<mlir::Type> trailingInTys;
228     llvm::SmallVector<mlir::Value> trailingOpers;
229     for (auto e : llvm::enumerate(
230              llvm::zip(fnTy.getInputs().drop_front(dropFront),
231                        callOp.getOperands().drop_front(dropFront)))) {
232       mlir::Type ty = std::get<0>(e.value());
233       mlir::Value oper = std::get<1>(e.value());
234       unsigned index = e.index();
235       llvm::TypeSwitch<mlir::Type>(ty)
236           .template Case<BoxCharType>([&](BoxCharType boxTy) {
237             bool sret;
238             if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
239               sret = callOp.callee() &&
240                      functionArgIsSRet(index,
241                                        getModule().lookupSymbol<mlir::FuncOp>(
242                                            *callOp.callee()));
243             } else {
244               // TODO: dispatch case; how do we put arguments on a call?
245               // We cannot put both an sret and the dispatch object first.
246               sret = false;
247               TODO(loc, "dispatch + sret not supported yet");
248             }
249             auto m = specifics->boxcharArgumentType(boxTy.getEleTy(), sret);
250             auto unbox =
251                 rewriter->create<UnboxCharOp>(loc, std::get<mlir::Type>(m[0]),
252                                               std::get<mlir::Type>(m[1]), oper);
253             // unboxed CHARACTER arguments
254             for (auto e : llvm::enumerate(m)) {
255               unsigned idx = e.index();
256               auto attr = std::get<CodeGenSpecifics::Attributes>(e.value());
257               auto argTy = std::get<mlir::Type>(e.value());
258               if (attr.isAppend()) {
259                 trailingInTys.push_back(argTy);
260                 trailingOpers.push_back(unbox.getResult(idx));
261               } else {
262                 newInTys.push_back(argTy);
263                 newOpers.push_back(unbox.getResult(idx));
264               }
265             }
266           })
267           .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
268             rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers);
269           })
270           .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
271             rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers);
272           })
273           .template Case<mlir::TupleType>([&](mlir::TupleType tuple) {
274             if (factory::isCharacterProcedureTuple(tuple)) {
275               mlir::ModuleOp module = getModule();
276               if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
277                 if (callOp.callee()) {
278                   llvm::StringRef charProcAttr =
279                       fir::getCharacterProcedureDummyAttrName();
280                   // The charProcAttr attribute is only used as a safety to
281                   // confirm that this is a dummy procedure and should be split.
282                   // It cannot be used to match because attributes are not
283                   // available in case of indirect calls.
284                   auto funcOp =
285                       module.lookupSymbol<mlir::FuncOp>(*callOp.callee());
286                   if (funcOp &&
287                       !funcOp.template getArgAttrOfType<mlir::UnitAttr>(
288                           index, charProcAttr))
289                     mlir::emitError(loc, "tuple argument will be split even "
290                                          "though it does not have the `" +
291                                              charProcAttr + "` attribute");
292                 }
293               }
294               mlir::Type funcPointerType = tuple.getType(0);
295               mlir::Type lenType = tuple.getType(1);
296               FirOpBuilder builder(*rewriter, getKindMapping(module));
297               auto [funcPointer, len] =
298                   factory::extractCharacterProcedureTuple(builder, loc, oper);
299               newInTys.push_back(funcPointerType);
300               newOpers.push_back(funcPointer);
301               trailingInTys.push_back(lenType);
302               trailingOpers.push_back(len);
303             } else {
304               newInTys.push_back(tuple);
305               newOpers.push_back(oper);
306             }
307           })
308           .Default([&](mlir::Type ty) {
309             newInTys.push_back(ty);
310             newOpers.push_back(oper);
311           });
312     }
313     newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end());
314     newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());
315     if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
316       fir::CallOp newCall;
317       if (callOp.callee().hasValue()) {
318         newCall = rewriter->create<A>(loc, callOp.callee().getValue(),
319                                       newResTys, newOpers);
320       } else {
321         // Force new type on the input operand.
322         newOpers[0].setType(mlir::FunctionType::get(
323             callOp.getContext(),
324             mlir::TypeRange{newInTys}.drop_front(dropFront), newResTys));
325         newCall = rewriter->create<A>(loc, newResTys, newOpers);
326       }
327       LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n');
328       if (wrap.hasValue())
329         replaceOp(callOp, (*wrap)(newCall.getOperation()));
330       else
331         replaceOp(callOp, newCall.getResults());
332     } else {
333       // A is fir::DispatchOp
334       TODO(loc, "dispatch not implemented");
335     }
336   }
337 
338   // Result type fixup for fir::ComplexType and mlir::ComplexType
339   template <typename A, typename B>
340   void lowerComplexSignatureRes(A cmplx, B &newResTys, B &newInTys) {
341     if (noComplexConversion) {
342       newResTys.push_back(cmplx);
343     } else {
344       for (auto &tup : specifics->complexReturnType(cmplx.getElementType())) {
345         auto argTy = std::get<mlir::Type>(tup);
346         if (std::get<CodeGenSpecifics::Attributes>(tup).isSRet())
347           newInTys.push_back(argTy);
348         else
349           newResTys.push_back(argTy);
350       }
351     }
352   }
353 
354   // Argument type fixup for fir::ComplexType and mlir::ComplexType
355   template <typename A, typename B>
356   void lowerComplexSignatureArg(A cmplx, B &newInTys) {
357     if (noComplexConversion)
358       newInTys.push_back(cmplx);
359     else
360       for (auto &tup : specifics->complexArgumentType(cmplx.getElementType()))
361         newInTys.push_back(std::get<mlir::Type>(tup));
362   }
363 
  /// Taking the address of a function. Modify the signature as needed.
  ///
  /// Recomputes the function type of `addrOp` for the target ABI, using the
  /// same per-type rules as the signature conversion:
  /// - COMPLEX results and arguments go through `lowerComplexSignature*`.
  /// - Boxchar arguments are split unless `noCharacterConversion` is set.
  /// - Character procedure tuples are split into pointer and length.
  ///
  /// Parts flagged to be appended are placed after all other inputs. A new
  /// AddrOfOp with the updated type, inserted before `addrOp`, replaces it.
  void convertAddrOp(AddrOfOp addrOp) {
    rewriter->setInsertionPoint(addrOp);
    auto addrTy = addrOp.getType().cast<mlir::FunctionType>();
    llvm::SmallVector<mlir::Type> newResTys;
    llvm::SmallVector<mlir::Type> newInTys;
    // Results are handled first, so any sret input added by
    // lowerComplexSignatureRes precedes the converted arguments.
    for (mlir::Type ty : addrTy.getResults()) {
      llvm::TypeSwitch<mlir::Type>(ty)
          .Case<fir::ComplexType>([&](fir::ComplexType ty) {
            lowerComplexSignatureRes(ty, newResTys, newInTys);
          })
          .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
            lowerComplexSignatureRes(ty, newResTys, newInTys);
          })
          .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
    }
    // Input types that must be appended after all other inputs.
    llvm::SmallVector<mlir::Type> trailingInTys;
    for (mlir::Type ty : addrTy.getInputs()) {
      llvm::TypeSwitch<mlir::Type>(ty)
          .Case<BoxCharType>([&](BoxCharType box) {
            if (noCharacterConversion) {
              newInTys.push_back(box);
            } else {
              // NOTE(review): unlike convertSignature/convertCallOp, no `sret`
              // flag is passed to boxcharArgumentType here, so a boxchar
              // argument carrying `llvm.sret` may be laid out differently
              // from the converted function's signature. Confirm intended.
              for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) {
                auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
                auto argTy = std::get<mlir::Type>(tup);
                llvm::SmallVector<mlir::Type> &vec =
                    attr.isAppend() ? trailingInTys : newInTys;
                vec.push_back(argTy);
              }
            }
          })
          .Case<fir::ComplexType>([&](fir::ComplexType ty) {
            lowerComplexSignatureArg(ty, newInTys);
          })
          .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
            lowerComplexSignatureArg(ty, newInTys);
          })
          .Case<mlir::TupleType>([&](mlir::TupleType tuple) {
            // Character procedure: pointer in place, length at the end.
            if (factory::isCharacterProcedureTuple(tuple)) {
              newInTys.push_back(tuple.getType(0));
              trailingInTys.push_back(tuple.getType(1));
            } else {
              newInTys.push_back(ty);
            }
          })
          .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
    }
    // append trailing input types
    newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end());
    // replace this op with a new one with the updated signature
    auto newTy = rewriter->getFunctionType(newInTys, newResTys);
    auto newOp =
        rewriter->create<AddrOfOp>(addrOp.getLoc(), newTy, addrOp.symbol());
    replaceOp(addrOp, newOp.getResult());
  }
420 
421   /// Convert the type signatures on all the functions present in the module.
422   /// As the type signature is being changed, this must also update the
423   /// function itself to use any new arguments, etc.
424   mlir::LogicalResult convertTypes(mlir::ModuleOp mod) {
425     for (auto fn : mod.getOps<mlir::FuncOp>())
426       convertSignature(fn);
427     return mlir::success();
428   }
429 
430   /// If the signature does not need any special target-specific converions,
431   /// then it is considered portable for any target, and this function will
432   /// return `true`. Otherwise, the signature is not portable and `false` is
433   /// returned.
434   bool hasPortableSignature(mlir::Type signature) {
435     assert(signature.isa<mlir::FunctionType>());
436     auto func = signature.dyn_cast<mlir::FunctionType>();
437     for (auto ty : func.getResults())
438       if ((ty.isa<BoxCharType>() && !noCharacterConversion) ||
439           (isa_complex(ty) && !noComplexConversion)) {
440         LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
441         return false;
442       }
443     for (auto ty : func.getInputs())
444       if (((ty.isa<BoxCharType>() || factory::isCharacterProcedureTuple(ty)) &&
445            !noCharacterConversion) ||
446           (isa_complex(ty) && !noComplexConversion)) {
447         LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
448         return false;
449       }
450     return true;
451   }
452 
  /// Rewrite the signatures and body of the `FuncOp`s in the module for
  /// the immediately subsequent target code gen.
  ///
  /// For one function `func` this does four things in order:
  ///  1. Computes the new result types, input types and a list of `FixupTy`
  ///     records. Functions whose signature is already portable return early.
  ///  2. If `func` has a body, applies each fixup in place to the entry block
  ///     arguments and to the `mlir::ReturnOp`s.
  ///  3. Appends the trailing input types and sets the new function type.
  ///  4. Runs the finalizer of every fixup that has one.
  void convertSignature(mlir::FuncOp func) {
    auto funcTy = func.getType().cast<mlir::FunctionType>();
    if (hasPortableSignature(funcTy))
      return;
    llvm::SmallVector<mlir::Type> newResTys;
    llvm::SmallVector<mlir::Type> newInTys;
    llvm::SmallVector<FixupTy> fixups;

    // Convert return value(s)
    for (auto ty : funcTy.getResults())
      llvm::TypeSwitch<mlir::Type>(ty)
          .Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
            if (noComplexConversion)
              newResTys.push_back(cmplx);
            else
              doComplexReturn(func, cmplx, newResTys, newInTys, fixups);
          })
          .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
            if (noComplexConversion)
              newResTys.push_back(cmplx);
            else
              doComplexReturn(func, cmplx, newResTys, newInTys, fixups);
          })
          .Default([&](mlir::Type ty) { newResTys.push_back(ty); });

    // Convert arguments
    // Argument types that will be appended after all other arguments.
    llvm::SmallVector<mlir::Type> trailingTys;
    for (auto e : llvm::enumerate(funcTy.getInputs())) {
      auto ty = e.value();
      unsigned index = e.index();
      llvm::TypeSwitch<mlir::Type>(ty)
          .Case<BoxCharType>([&](BoxCharType boxTy) {
            if (noCharacterConversion) {
              newInTys.push_back(boxTy);
            } else {
              // Convert a CHARACTER argument type. This can involve separating
              // the pointer and the LEN into two arguments and moving the LEN
              // argument to the end of the arg list.
              bool sret = functionArgIsSRet(index, func);
              // Note: `e` and `index` inside this loop shadow the outer loop
              // variables; here `index` is the position of the part within
              // the boxchar and becomes the CharPair fixup's `second`.
              for (auto e : llvm::enumerate(specifics->boxcharArgumentType(
                       boxTy.getEleTy(), sret))) {
                auto &tup = e.value();
                auto index = e.index();
                auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
                auto argTy = std::get<mlir::Type>(tup);
                if (attr.isAppend()) {
                  trailingTys.push_back(argTy);
                } else {
                  if (sret) {
                    fixups.emplace_back(FixupTy::Codes::CharPair,
                                        newInTys.size(), index);
                  } else {
                    // `trailingTys.size()` is the slot the appended LEN part
                    // is about to take.
                    fixups.emplace_back(FixupTy::Codes::Trailing,
                                        newInTys.size(), trailingTys.size());
                  }
                  newInTys.push_back(argTy);
                }
              }
            }
          })
          .Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
            if (noComplexConversion)
              newInTys.push_back(cmplx);
            else
              doComplexArg(func, cmplx, newInTys, fixups);
          })
          .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
            if (noComplexConversion)
              newInTys.push_back(cmplx);
            else
              doComplexArg(func, cmplx, newInTys, fixups);
          })
          .Case<mlir::TupleType>([&](mlir::TupleType tuple) {
            // NOTE(review): unlike the BoxCharType case, this does not
            // consult `noCharacterConversion`, although hasPortableSignature
            // does. The call-site and address-of rewrites split these tuples
            // the same way; confirm this is intended.
            if (factory::isCharacterProcedureTuple(tuple)) {
              fixups.emplace_back(FixupTy::Codes::TrailingCharProc,
                                  newInTys.size(), trailingTys.size());
              newInTys.push_back(tuple.getType(0));
              trailingTys.push_back(tuple.getType(1));
            } else {
              newInTys.push_back(ty);
            }
          })
          .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
    }

    if (!func.empty()) {
      // If the function has a body, then apply the fixups to the arguments and
      // return ops as required. These fixups are done in place.
      //
      // Most cases insert the new argument at `fixup.index`. That shifts the
      // old argument to `fixup.index + 1`, where it is replaced and erased.
      auto loc = func.getLoc();
      const auto fixupSize = fixups.size();
      const auto oldArgTys = func.getType().getInputs();
      // Number of arguments inserted so far that have no counterpart in the
      // original signature (incremented by CharPair, ReturnAsStore and
      // Split). Subtracted from `fixup.index` to index into `oldArgTys`.
      int offset = 0;
      for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) {
        const auto &fixup = fixups[i];
        switch (fixup.code) {
        case FixupTy::Codes::ArgumentAsLoad: {
          // Argument was pass-by-value, but is now pass-by-reference and
          // possibly with a different element type.
          auto newArg = func.front().insertArgument(fixup.index,
                                                    newInTys[fixup.index], loc);
          rewriter->setInsertionPointToStart(&func.front());
          auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]);
          auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, newArg);
          auto load = rewriter->create<fir::LoadOp>(loc, cast);
          func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
          func.front().eraseArgument(fixup.index + 1);
        } break;
        case FixupTy::Codes::ArgumentType: {
          // Argument is pass-by-value, but its type has likely been modified to
          // suit the target ABI convention.
          auto newArg = func.front().insertArgument(fixup.index,
                                                    newInTys[fixup.index], loc);
          rewriter->setInsertionPointToStart(&func.front());
          // Spill the new argument and reload it as the original type.
          auto mem =
              rewriter->create<fir::AllocaOp>(loc, newInTys[fixup.index]);
          rewriter->create<fir::StoreOp>(loc, newArg, mem);
          auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]);
          auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, mem);
          mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast);
          func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
          func.front().eraseArgument(fixup.index + 1);
          LLVM_DEBUG(llvm::dbgs()
                     << "old argument: " << oldArgTy.getEleTy()
                     << ", repl: " << load << ", new argument: "
                     << func.getArgument(fixup.index).getType() << '\n');
        } break;
        case FixupTy::Codes::CharPair: {
          // The FIR boxchar argument has been split into a pair of distinct
          // arguments that are in juxtaposition to each other.
          auto newArg = func.front().insertArgument(fixup.index,
                                                    newInTys[fixup.index], loc);
          // Only once the second part is inserted are both parts (at
          // fixup.index - 1 and fixup.index) re-boxed to replace the old
          // argument.
          if (fixup.second == 1) {
            rewriter->setInsertionPointToStart(&func.front());
            auto boxTy = oldArgTys[fixup.index - offset - fixup.second];
            auto box = rewriter->create<EmboxCharOp>(
                loc, boxTy, func.front().getArgument(fixup.index - 1), newArg);
            func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
            func.front().eraseArgument(fixup.index + 1);
            offset++;
          }
        } break;
        case FixupTy::Codes::ReturnAsStore: {
          // The value being returned is now being returned in memory (callee
          // stack space) through a hidden reference argument.
          auto newArg = func.front().insertArgument(fixup.index,
                                                    newInTys[fixup.index], loc);
          offset++;
          // Replace every `return %v` with a store through the new argument
          // followed by an operand-less return.
          func.walk([&](mlir::ReturnOp ret) {
            rewriter->setInsertionPoint(ret);
            auto oldOper = ret.getOperand(0);
            auto oldOperTy = ReferenceType::get(oldOper.getType());
            auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, newArg);
            rewriter->create<fir::StoreOp>(loc, oldOper, cast);
            rewriter->create<mlir::ReturnOp>(loc);
            ret.erase();
          });
        } break;
        case FixupTy::Codes::ReturnType: {
          // The function is still returning a value, but its type has likely
          // changed to suit the target ABI convention.
          // Here `fixup.index` indexes `newResTys`, not `newInTys`.
          func.walk([&](mlir::ReturnOp ret) {
            rewriter->setInsertionPoint(ret);
            auto oldOper = ret.getOperand(0);
            auto oldOperTy = ReferenceType::get(oldOper.getType());
            auto mem =
                rewriter->create<fir::AllocaOp>(loc, newResTys[fixup.index]);
            auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, mem);
            rewriter->create<fir::StoreOp>(loc, oldOper, cast);
            mlir::Value load = rewriter->create<fir::LoadOp>(loc, mem);
            rewriter->create<mlir::ReturnOp>(loc, load);
            ret.erase();
          });
        } break;
        case FixupTy::Codes::Split: {
          // The FIR argument has been split into a pair of distinct arguments
          // that are in juxtaposition to each other. (For COMPLEX value.)
          auto newArg = func.front().insertArgument(fixup.index,
                                                    newInTys[fixup.index], loc);
          // After the second part is inserted, rebuild the COMPLEX value
          // with fir.undefined + two fir.insert_value ops.
          if (fixup.second == 1) {
            rewriter->setInsertionPointToStart(&func.front());
            auto cplxTy = oldArgTys[fixup.index - offset - fixup.second];
            auto undef = rewriter->create<UndefOp>(loc, cplxTy);
            auto iTy = rewriter->getIntegerType(32);
            auto zero = rewriter->getIntegerAttr(iTy, 0);
            auto one = rewriter->getIntegerAttr(iTy, 1);
            auto cplx1 = rewriter->create<InsertValueOp>(
                loc, cplxTy, undef, func.front().getArgument(fixup.index - 1),
                rewriter->getArrayAttr(zero));
            auto cplx = rewriter->create<InsertValueOp>(
                loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr(one));
            func.getArgument(fixup.index + 1).replaceAllUsesWith(cplx);
            func.front().eraseArgument(fixup.index + 1);
            offset++;
          }
        } break;
        case FixupTy::Codes::Trailing: {
          // The FIR argument has been split into a pair of distinct arguments.
          // The first part of the pair appears in the original argument
          // position. The second part of the pair is appended after all the
          // original arguments. (Boxchar arguments.)
          // `fixup.second` selects the appended part's type in trailingTys.
          auto newBufArg = func.front().insertArgument(
              fixup.index, newInTys[fixup.index], loc);
          auto newLenArg =
              func.front().addArgument(trailingTys[fixup.second], loc);
          auto boxTy = oldArgTys[fixup.index - offset];
          rewriter->setInsertionPointToStart(&func.front());
          auto box =
              rewriter->create<EmboxCharOp>(loc, boxTy, newBufArg, newLenArg);
          func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
          func.front().eraseArgument(fixup.index + 1);
        } break;
        case FixupTy::Codes::TrailingCharProc: {
          // The FIR character procedure argument tuple has been split into a
          // pair of distinct arguments. The first part of the pair appears in
          // the original argument position. The second part of the pair is
          // appended after all the original arguments.
          auto newProcPointerArg = func.front().insertArgument(
              fixup.index, newInTys[fixup.index], loc);
          auto newLenArg =
              func.front().addArgument(trailingTys[fixup.second], loc);
          auto tupleType = oldArgTys[fixup.index - offset];
          rewriter->setInsertionPointToStart(&func.front());
          FirOpBuilder builder(*rewriter, getKindMapping(getModule()));
          auto tuple = factory::createCharacterProcedureTuple(
              builder, loc, tupleType, newProcPointerArg, newLenArg);
          func.getArgument(fixup.index + 1).replaceAllUsesWith(tuple);
          func.front().eraseArgument(fixup.index + 1);
        } break;
        }
      }
    }

    // Set the new type and finalize the arguments, etc.
    newInTys.insert(newInTys.end(), trailingTys.begin(), trailingTys.end());
    auto newFuncTy =
        mlir::FunctionType::get(func.getContext(), newInTys, newResTys);
    LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n');
    func.setType(newFuncTy);

    // Finalizers run only after the new type is in place.
    for (auto &fixup : fixups)
      if (fixup.finalizer)
        (*fixup.finalizer)(func);
  }
698 
699   inline bool functionArgIsSRet(unsigned index, mlir::FuncOp func) {
700     if (auto attr = func.getArgAttrOfType<mlir::UnitAttr>(index, "llvm.sret"))
701       return true;
702     return false;
703   }
704 
705   /// Convert a complex return value. This can involve converting the return
706   /// value to a "hidden" first argument or packing the complex into a wide
707   /// GPR.
708   template <typename A, typename B, typename C>
709   void doComplexReturn(mlir::FuncOp func, A cmplx, B &newResTys, B &newInTys,
710                        C &fixups) {
711     if (noComplexConversion) {
712       newResTys.push_back(cmplx);
713       return;
714     }
715     auto m = specifics->complexReturnType(cmplx.getElementType());
716     assert(m.size() == 1);
717     auto &tup = m[0];
718     auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
719     auto argTy = std::get<mlir::Type>(tup);
720     if (attr.isSRet()) {
721       unsigned argNo = newInTys.size();
722       fixups.emplace_back(
723           FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::FuncOp func) {
724             func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr());
725           });
726       newInTys.push_back(argTy);
727       return;
728     }
729     fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size());
730     newResTys.push_back(argTy);
731   }
732 
733   /// Convert a complex argument value. This can involve storing the value to
734   /// a temporary memory location or factoring the value into two distinct
735   /// arguments.
736   template <typename A, typename B, typename C>
737   void doComplexArg(mlir::FuncOp func, A cmplx, B &newInTys, C &fixups) {
738     if (noComplexConversion) {
739       newInTys.push_back(cmplx);
740       return;
741     }
742     auto m = specifics->complexArgumentType(cmplx.getElementType());
743     const auto fixupCode =
744         m.size() > 1 ? FixupTy::Codes::Split : FixupTy::Codes::ArgumentType;
745     for (auto e : llvm::enumerate(m)) {
746       auto &tup = e.value();
747       auto index = e.index();
748       auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
749       auto argTy = std::get<mlir::Type>(tup);
750       auto argNo = newInTys.size();
751       if (attr.isByVal()) {
752         if (auto align = attr.getAlignment())
753           fixups.emplace_back(
754               FixupTy::Codes::ArgumentAsLoad, argNo, [=](mlir::FuncOp func) {
755                 func.setArgAttr(argNo, "llvm.byval", rewriter->getUnitAttr());
756                 func.setArgAttr(argNo, "llvm.align",
757                                 rewriter->getIntegerAttr(
758                                     rewriter->getIntegerType(32), align));
759               });
760         else
761           fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, newInTys.size(),
762                               [=](mlir::FuncOp func) {
763                                 func.setArgAttr(argNo, "llvm.byval",
764                                                 rewriter->getUnitAttr());
765                               });
766       } else {
767         if (auto align = attr.getAlignment())
768           fixups.emplace_back(fixupCode, argNo, index, [=](mlir::FuncOp func) {
769             func.setArgAttr(
770                 argNo, "llvm.align",
771                 rewriter->getIntegerAttr(rewriter->getIntegerType(32), align));
772           });
773         else
774           fixups.emplace_back(fixupCode, argNo, index);
775       }
776       newInTys.push_back(argTy);
777     }
778   }
779 
780 private:
  // Replace `op` and remove it: redirect every use of `op`'s results to
  // `newValues`, drop `op`'s own operand references, then erase `op`.
  // NOTE(review): `newValues` presumably must supply one value per result of
  // `op` (checked inside replaceAllUsesWith) — confirm at call sites.
  void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) {
    op->replaceAllUsesWith(newValues);
    op->dropAllReferences();
    op->erase();
  }
787 
788   inline void setMembers(CodeGenSpecifics *s, mlir::OpBuilder *r) {
789     specifics = s;
790     rewriter = r;
791   }
792 
793   inline void clearMembers() { setMembers(nullptr, nullptr); }
794 
795   CodeGenSpecifics *specifics{};
796   mlir::OpBuilder *rewriter;
797 }; // namespace
798 } // namespace
799 
800 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
801 fir::createFirTargetRewritePass(const TargetRewriteOptions &options) {
802   return std::make_unique<TargetRewrite>(options);
803 }
804