1 //===-- TargetRewrite.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Target rewrite: rewriting of ops to make target-specific lowerings manifest.
10 // LLVM expects different lowering idioms to be used for distinct target
11 // triples. These distinctions are handled by this pass.
12 //
13 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
14 //
15 //===----------------------------------------------------------------------===//
16 
#include "PassDetail.h"
#include "Target.h"
#include "flang/Lower/Todo.h"
#include "flang/Optimizer/CodeGen/CodeGen.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/FIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <functional>
#include <utility>
29 
30 using namespace fir;
31 
32 #define DEBUG_TYPE "flang-target-rewrite"
33 
34 namespace {
35 
36 /// Fixups for updating a FuncOp's arguments and return values.
37 struct FixupTy {
38   enum class Codes { CharPair, Trailing };
39 
40   FixupTy(Codes code, std::size_t index, std::size_t second = 0)
41       : code{code}, index{index}, second{second} {}
42   FixupTy(Codes code, std::size_t index,
43           std::function<void(mlir::FuncOp)> &&finalizer)
44       : code{code}, index{index}, finalizer{finalizer} {}
45   FixupTy(Codes code, std::size_t index, std::size_t second,
46           std::function<void(mlir::FuncOp)> &&finalizer)
47       : code{code}, index{index}, second{second}, finalizer{finalizer} {}
48 
49   Codes code;
50   std::size_t index;
51   std::size_t second{};
52   llvm::Optional<std::function<void(mlir::FuncOp)>> finalizer{};
53 }; // namespace
54 
55 /// Target-specific rewriting of the FIR. This is a prerequisite pass to code
56 /// generation that traverses the FIR and modifies types and operations to a
57 /// form that is appropriate for the specific target. LLVM IR has specific
58 /// idioms that are used for distinct target processor and ABI combinations.
59 class TargetRewrite : public TargetRewriteBase<TargetRewrite> {
60 public:
61   TargetRewrite(const TargetRewriteOptions &options) {
62     noCharacterConversion = options.noCharacterConversion;
63   }
64 
65   void runOnOperation() override final {
66     auto &context = getContext();
67     mlir::OpBuilder rewriter(&context);
68 
69     auto mod = getModule();
70     if (!forcedTargetTriple.empty()) {
71       setTargetTriple(mod, forcedTargetTriple);
72     }
73 
74     auto specifics = CodeGenSpecifics::get(getOperation().getContext(),
75                                            getTargetTriple(getOperation()),
76                                            getKindMapping(getOperation()));
77     setMembers(specifics.get(), &rewriter);
78 
79     // Perform type conversion on signatures and call sites.
80     if (mlir::failed(convertTypes(mod))) {
81       mlir::emitError(mlir::UnknownLoc::get(&context),
82                       "error in converting types to target abi");
83       signalPassFailure();
84     }
85 
86     // Convert ops in target-specific patterns.
87     mod.walk([&](mlir::Operation *op) {
88       if (auto call = dyn_cast<fir::CallOp>(op)) {
89         if (!hasPortableSignature(call.getFunctionType()))
90           convertCallOp(call);
91       } else if (auto dispatch = dyn_cast<DispatchOp>(op)) {
92         if (!hasPortableSignature(dispatch.getFunctionType()))
93           convertCallOp(dispatch);
94       }
95     });
96 
97     clearMembers();
98   }
99 
100   mlir::ModuleOp getModule() { return getOperation(); }
101 
102   // Convert fir.call and fir.dispatch Ops.
103   template <typename A>
104   void convertCallOp(A callOp) {
105     auto fnTy = callOp.getFunctionType();
106     auto loc = callOp.getLoc();
107     rewriter->setInsertionPoint(callOp);
108     llvm::SmallVector<mlir::Type> newResTys;
109     llvm::SmallVector<mlir::Type> newInTys;
110     llvm::SmallVector<mlir::Value> newOpers;
111 
112     // If the call is indirect, the first argument must still be the function
113     // to call.
114     int dropFront = 0;
115     if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
116       if (!callOp.callee().hasValue()) {
117         newInTys.push_back(fnTy.getInput(0));
118         newOpers.push_back(callOp.getOperand(0));
119         dropFront = 1;
120       }
121     }
122 
123     // Determine the rewrite function, `wrap`, for the result value.
124     llvm::Optional<std::function<mlir::Value(mlir::Operation *)>> wrap;
125     if (fnTy.getResults().size() == 1) {
126       mlir::Type ty = fnTy.getResult(0);
127       newResTys.push_back(ty);
128     } else if (fnTy.getResults().size() > 1) {
129       TODO(loc, "multiple results not supported yet");
130     }
131 
132     llvm::SmallVector<mlir::Type> trailingInTys;
133     llvm::SmallVector<mlir::Value> trailingOpers;
134     for (auto e : llvm::enumerate(
135              llvm::zip(fnTy.getInputs().drop_front(dropFront),
136                        callOp.getOperands().drop_front(dropFront)))) {
137       mlir::Type ty = std::get<0>(e.value());
138       mlir::Value oper = std::get<1>(e.value());
139       unsigned index = e.index();
140       llvm::TypeSwitch<mlir::Type>(ty)
141           .template Case<BoxCharType>([&](BoxCharType boxTy) {
142             bool sret;
143             if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
144               sret = callOp.callee() &&
145                      functionArgIsSRet(index,
146                                        getModule().lookupSymbol<mlir::FuncOp>(
147                                            *callOp.callee()));
148             } else {
149               // TODO: dispatch case; how do we put arguments on a call?
150               // We cannot put both an sret and the dispatch object first.
151               sret = false;
152               TODO(loc, "dispatch + sret not supported yet");
153             }
154             auto m = specifics->boxcharArgumentType(boxTy.getEleTy(), sret);
155             auto unbox =
156                 rewriter->create<UnboxCharOp>(loc, std::get<mlir::Type>(m[0]),
157                                               std::get<mlir::Type>(m[1]), oper);
158             // unboxed CHARACTER arguments
159             for (auto e : llvm::enumerate(m)) {
160               unsigned idx = e.index();
161               auto attr = std::get<CodeGenSpecifics::Attributes>(e.value());
162               auto argTy = std::get<mlir::Type>(e.value());
163               if (attr.isAppend()) {
164                 trailingInTys.push_back(argTy);
165                 trailingOpers.push_back(unbox.getResult(idx));
166               } else {
167                 newInTys.push_back(argTy);
168                 newOpers.push_back(unbox.getResult(idx));
169               }
170             }
171           })
172           .Default([&](mlir::Type ty) {
173             newInTys.push_back(ty);
174             newOpers.push_back(oper);
175           });
176     }
177     newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end());
178     newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());
179     if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
180       fir::CallOp newCall;
181       if (callOp.callee().hasValue()) {
182         newCall = rewriter->create<A>(loc, callOp.callee().getValue(),
183                                       newResTys, newOpers);
184       } else {
185         // Force new type on the input operand.
186         newOpers[0].setType(mlir::FunctionType::get(
187             callOp.getContext(),
188             mlir::TypeRange{newInTys}.drop_front(dropFront), newResTys));
189         newCall = rewriter->create<A>(loc, newResTys, newOpers);
190       }
191       LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n');
192       if (wrap.hasValue())
193         replaceOp(callOp, (*wrap)(newCall.getOperation()));
194       else
195         replaceOp(callOp, newCall.getResults());
196     } else {
197       // A is fir::DispatchOp
198       TODO(loc, "dispatch not implemented");
199     }
200   }
201   /// Convert the type signatures on all the functions present in the module.
202   /// As the type signature is being changed, this must also update the
203   /// function itself to use any new arguments, etc.
204   mlir::LogicalResult convertTypes(mlir::ModuleOp mod) {
205     for (auto fn : mod.getOps<mlir::FuncOp>())
206       convertSignature(fn);
207     return mlir::success();
208   }
209 
210   /// If the signature does not need any special target-specific converions,
211   /// then it is considered portable for any target, and this function will
212   /// return `true`. Otherwise, the signature is not portable and `false` is
213   /// returned.
214   bool hasPortableSignature(mlir::Type signature) {
215     assert(signature.isa<mlir::FunctionType>());
216     auto func = signature.dyn_cast<mlir::FunctionType>();
217     for (auto ty : func.getResults())
218       if ((ty.isa<BoxCharType>() && !noCharacterConversion)) {
219         LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
220         return false;
221       }
222     for (auto ty : func.getInputs())
223       if ((ty.isa<BoxCharType>() && !noCharacterConversion)) {
224         LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
225         return false;
226       }
227     return true;
228   }
229 
230   /// Rewrite the signatures and body of the `FuncOp`s in the module for
231   /// the immediately subsequent target code gen.
232   void convertSignature(mlir::FuncOp func) {
233     auto funcTy = func.getType().cast<mlir::FunctionType>();
234     if (hasPortableSignature(funcTy))
235       return;
236     llvm::SmallVector<mlir::Type> newResTys;
237     llvm::SmallVector<mlir::Type> newInTys;
238     llvm::SmallVector<FixupTy> fixups;
239 
240     // Convert return value(s)
241     for (auto ty : funcTy.getResults())
242       newResTys.push_back(ty);
243 
244     // Convert arguments
245     llvm::SmallVector<mlir::Type> trailingTys;
246     for (auto e : llvm::enumerate(funcTy.getInputs())) {
247       auto ty = e.value();
248       unsigned index = e.index();
249       llvm::TypeSwitch<mlir::Type>(ty)
250           .Case<BoxCharType>([&](BoxCharType boxTy) {
251             if (noCharacterConversion) {
252               newInTys.push_back(boxTy);
253             } else {
254               // Convert a CHARACTER argument type. This can involve separating
255               // the pointer and the LEN into two arguments and moving the LEN
256               // argument to the end of the arg list.
257               bool sret = functionArgIsSRet(index, func);
258               for (auto e : llvm::enumerate(specifics->boxcharArgumentType(
259                        boxTy.getEleTy(), sret))) {
260                 auto &tup = e.value();
261                 auto index = e.index();
262                 auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
263                 auto argTy = std::get<mlir::Type>(tup);
264                 if (attr.isAppend()) {
265                   trailingTys.push_back(argTy);
266                 } else {
267                   if (sret) {
268                     fixups.emplace_back(FixupTy::Codes::CharPair,
269                                         newInTys.size(), index);
270                   } else {
271                     fixups.emplace_back(FixupTy::Codes::Trailing,
272                                         newInTys.size(), trailingTys.size());
273                   }
274                   newInTys.push_back(argTy);
275                 }
276               }
277             }
278           })
279           .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
280     }
281 
282     if (!func.empty()) {
283       // If the function has a body, then apply the fixups to the arguments and
284       // return ops as required. These fixups are done in place.
285       auto loc = func.getLoc();
286       const auto fixupSize = fixups.size();
287       const auto oldArgTys = func.getType().getInputs();
288       int offset = 0;
289       for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) {
290         const auto &fixup = fixups[i];
291         switch (fixup.code) {
292         case FixupTy::Codes::CharPair: {
293           // The FIR boxchar argument has been split into a pair of distinct
294           // arguments that are in juxtaposition to each other.
295           auto newArg =
296               func.front().insertArgument(fixup.index, newInTys[fixup.index]);
297           if (fixup.second == 1) {
298             rewriter->setInsertionPointToStart(&func.front());
299             auto boxTy = oldArgTys[fixup.index - offset - fixup.second];
300             auto box = rewriter->create<EmboxCharOp>(
301                 loc, boxTy, func.front().getArgument(fixup.index - 1), newArg);
302             func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
303             func.front().eraseArgument(fixup.index + 1);
304             offset++;
305           }
306         } break;
307         case FixupTy::Codes::Trailing: {
308           // The FIR argument has been split into a pair of distinct arguments.
309           // The first part of the pair appears in the original argument
310           // position. The second part of the pair is appended after all the
311           // original arguments. (Boxchar arguments.)
312           auto newBufArg =
313               func.front().insertArgument(fixup.index, newInTys[fixup.index]);
314           auto newLenArg = func.front().addArgument(trailingTys[fixup.second]);
315           auto boxTy = oldArgTys[fixup.index - offset];
316           rewriter->setInsertionPointToStart(&func.front());
317           auto box =
318               rewriter->create<EmboxCharOp>(loc, boxTy, newBufArg, newLenArg);
319           func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
320           func.front().eraseArgument(fixup.index + 1);
321         } break;
322         }
323       }
324     }
325 
326     // Set the new type and finalize the arguments, etc.
327     newInTys.insert(newInTys.end(), trailingTys.begin(), trailingTys.end());
328     auto newFuncTy =
329         mlir::FunctionType::get(func.getContext(), newInTys, newResTys);
330     LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n');
331     func.setType(newFuncTy);
332 
333     for (auto &fixup : fixups)
334       if (fixup.finalizer)
335         (*fixup.finalizer)(func);
336   }
337 
338   inline bool functionArgIsSRet(unsigned index, mlir::FuncOp func) {
339     if (auto attr = func.getArgAttrOfType<mlir::UnitAttr>(index, "llvm.sret"))
340       return true;
341     return false;
342   }
343 
344 private:
345   // Replace `op` and remove it.
346   void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) {
347     op->replaceAllUsesWith(newValues);
348     op->dropAllReferences();
349     op->erase();
350   }
351 
352   inline void setMembers(CodeGenSpecifics *s, mlir::OpBuilder *r) {
353     specifics = s;
354     rewriter = r;
355   }
356 
357   inline void clearMembers() { setMembers(nullptr, nullptr); }
358 
359   CodeGenSpecifics *specifics{};
360   mlir::OpBuilder *rewriter;
361 }; // namespace
362 } // namespace
363 
364 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
365 fir::createFirTargetRewritePass(const TargetRewriteOptions &options) {
366   return std::make_unique<TargetRewrite>(options);
367 }
368