1 //===-- Bridge.cpp -- bridge to lower to MLIR -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "flang/Lower/Bridge.h"
#include "flang/Evaluate/tools.h"
#include "flang/Lower/CallInterface.h"
#include "flang/Lower/Mangler.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Lower/Todo.h"
#include "flang/Optimizer/Support/FIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include <memory>
25 
26 #define DEBUG_TYPE "flang-lower-bridge"
27 
28 static llvm::cl::opt<bool> dumpBeforeFir(
29     "fdebug-dump-pre-fir", llvm::cl::init(false),
30     llvm::cl::desc("dump the Pre-FIR tree prior to FIR generation"));
31 
32 //===----------------------------------------------------------------------===//
33 // FirConverter
34 //===----------------------------------------------------------------------===//
35 
36 namespace {
37 
38 /// Traverse the pre-FIR tree (PFT) to generate the FIR dialect of MLIR.
39 class FirConverter : public Fortran::lower::AbstractConverter {
40 public:
41   explicit FirConverter(Fortran::lower::LoweringBridge &bridge)
42       : bridge{bridge}, foldingContext{bridge.createFoldingContext()} {}
43   virtual ~FirConverter() = default;
44 
45   /// Convert the PFT to FIR.
46   void run(Fortran::lower::pft::Program &pft) {
47     // Primary translation pass.
48     for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
49       std::visit(
50           Fortran::common::visitors{
51               [&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); },
52               [&](Fortran::lower::pft::ModuleLikeUnit &m) {},
53               [&](Fortran::lower::pft::BlockDataUnit &b) {},
54               [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {
55                 setCurrentPosition(
56                     d.get<Fortran::parser::CompilerDirective>().source);
57                 mlir::emitWarning(toLocation(),
58                                   "ignoring all compiler directives");
59               },
60           },
61           u);
62     }
63   }
64 
65   //===--------------------------------------------------------------------===//
66   // AbstractConverter overrides
67   //===--------------------------------------------------------------------===//
68 
69   mlir::Value getSymbolAddress(Fortran::lower::SymbolRef sym) override final {
70     return lookupSymbol(sym).getAddr();
71   }
72 
73   mlir::Value genExprAddr(const Fortran::lower::SomeExpr &expr,
74                           mlir::Location *loc = nullptr) override final {
75     TODO_NOLOC("Not implemented. Needed for more complex expression lowering");
76   }
77   mlir::Value genExprValue(const Fortran::lower::SomeExpr &expr,
78                            mlir::Location *loc = nullptr) override final {
79     TODO_NOLOC("Not implemented. Needed for more complex expression lowering");
80   }
81 
82   Fortran::evaluate::FoldingContext &getFoldingContext() override final {
83     return foldingContext;
84   }
85 
86   mlir::Type genType(const Fortran::evaluate::DataRef &) override final {
87     TODO_NOLOC("Not implemented. Needed for more complex expression lowering");
88   }
89   mlir::Type genType(const Fortran::lower::SomeExpr &) override final {
90     TODO_NOLOC("Not implemented. Needed for more complex expression lowering");
91   }
92   mlir::Type genType(Fortran::lower::SymbolRef) override final {
93     TODO_NOLOC("Not implemented. Needed for more complex expression lowering");
94   }
95   mlir::Type genType(Fortran::common::TypeCategory tc) override final {
96     TODO_NOLOC("Not implemented. Needed for more complex expression lowering");
97   }
98   mlir::Type genType(Fortran::common::TypeCategory tc,
99                      int kind) override final {
100     TODO_NOLOC("Not implemented. Needed for more complex expression lowering");
101   }
102   mlir::Type genType(const Fortran::lower::pft::Variable &) override final {
103     TODO_NOLOC("Not implemented. Needed for more complex expression lowering");
104   }
105 
106   void setCurrentPosition(const Fortran::parser::CharBlock &position) {
107     if (position != Fortran::parser::CharBlock{})
108       currentPosition = position;
109   }
110 
111   //===--------------------------------------------------------------------===//
112   // Utility methods
113   //===--------------------------------------------------------------------===//
114 
115   /// Convert a parser CharBlock to a Location
116   mlir::Location toLocation(const Fortran::parser::CharBlock &cb) {
117     return genLocation(cb);
118   }
119 
120   mlir::Location toLocation() { return toLocation(currentPosition); }
121   void setCurrentEval(Fortran::lower::pft::Evaluation &eval) {
122     evalPtr = &eval;
123   }
124 
125   mlir::Location getCurrentLocation() override final { return toLocation(); }
126 
127   /// Generate a dummy location.
128   mlir::Location genUnknownLocation() override final {
129     // Note: builder may not be instantiated yet
130     return mlir::UnknownLoc::get(&getMLIRContext());
131   }
132 
133   /// Generate a `Location` from the `CharBlock`.
134   mlir::Location
135   genLocation(const Fortran::parser::CharBlock &block) override final {
136     if (const Fortran::parser::AllCookedSources *cooked =
137             bridge.getCookedSource()) {
138       if (std::optional<std::pair<Fortran::parser::SourcePosition,
139                                   Fortran::parser::SourcePosition>>
140               loc = cooked->GetSourcePositionRange(block)) {
141         // loc is a pair (begin, end); use the beginning position
142         Fortran::parser::SourcePosition &filePos = loc->first;
143         return mlir::FileLineColLoc::get(&getMLIRContext(), filePos.file.path(),
144                                          filePos.line, filePos.column);
145       }
146     }
147     return genUnknownLocation();
148   }
149 
150   fir::FirOpBuilder &getFirOpBuilder() override final { return *builder; }
151 
152   mlir::ModuleOp &getModuleOp() override final { return bridge.getModule(); }
153 
154   mlir::MLIRContext &getMLIRContext() override final {
155     return bridge.getMLIRContext();
156   }
157   std::string
158   mangleName(const Fortran::semantics::Symbol &symbol) override final {
159     return Fortran::lower::mangle::mangleName(symbol);
160   }
161 
162   const fir::KindMapping &getKindMap() override final {
163     return bridge.getKindMap();
164   }
165 
166   /// Return the predicate: "current block does not have a terminator branch".
167   bool blockIsUnterminated() {
168     mlir::Block *currentBlock = builder->getBlock();
169     return currentBlock->empty() ||
170            !currentBlock->back().hasTrait<mlir::OpTrait::IsTerminator>();
171   }
172 
173   /// Emit return and cleanup after the function has been translated.
174   void endNewFunction(Fortran::lower::pft::FunctionLikeUnit &funit) {
175     setCurrentPosition(Fortran::lower::pft::stmtSourceLoc(funit.endStmt));
176     if (funit.isMainProgram())
177       genExitRoutine();
178     funit.finalBlock = nullptr;
179     LLVM_DEBUG(llvm::dbgs() << "*** Lowering result:\n\n"
180                             << *builder->getFunction() << '\n');
181     delete builder;
182     builder = nullptr;
183     localSymbols.clear();
184   }
185 
186   /// Prepare to translate a new function
187   void startNewFunction(Fortran::lower::pft::FunctionLikeUnit &funit) {
188     assert(!builder && "expected nullptr");
189     Fortran::lower::CalleeInterface callee(funit, *this);
190     mlir::FuncOp func = callee.addEntryBlockAndMapArguments();
191     func.setVisibility(mlir::SymbolTable::Visibility::Public);
192     builder = new fir::FirOpBuilder(func, bridge.getKindMap());
193     assert(builder && "FirOpBuilder did not instantiate");
194     builder->setInsertionPointToStart(&func.front());
195   }
196 
197   /// Lower a procedure (nest).
198   void lowerFunc(Fortran::lower::pft::FunctionLikeUnit &funit) {
199     setCurrentPosition(funit.getStartingSourceLoc());
200     for (int entryIndex = 0, last = funit.entryPointList.size();
201          entryIndex < last; ++entryIndex) {
202       funit.setActiveEntry(entryIndex);
203       startNewFunction(funit); // the entry point for lowering this procedure
204       endNewFunction(funit);
205     }
206     funit.setActiveEntry(0);
207     for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
208       lowerFunc(f); // internal procedure
209   }
210 
211 private:
212   FirConverter() = delete;
213   FirConverter(const FirConverter &) = delete;
214   FirConverter &operator=(const FirConverter &) = delete;
215 
216   //===--------------------------------------------------------------------===//
217   // Helper member functions
218   //===--------------------------------------------------------------------===//
219 
220   /// Find the symbol in the local map or return null.
221   Fortran::lower::SymbolBox
222   lookupSymbol(const Fortran::semantics::Symbol &sym) {
223     if (Fortran::lower::SymbolBox v = localSymbols.lookupSymbol(sym))
224       return v;
225     return {};
226   }
227 
228   //===--------------------------------------------------------------------===//
229   // Termination of symbolically referenced execution units
230   //===--------------------------------------------------------------------===//
231 
232   /// END of program
233   ///
234   /// Generate the cleanup block before the program exits
235   void genExitRoutine() {
236     if (blockIsUnterminated())
237       builder->create<mlir::ReturnOp>(toLocation());
238   }
239   void genFIR(const Fortran::parser::EndProgramStmt &) { genExitRoutine(); }
240 
241   //===--------------------------------------------------------------------===//
242 
243   Fortran::lower::LoweringBridge &bridge;
244   Fortran::evaluate::FoldingContext foldingContext;
245   fir::FirOpBuilder *builder = nullptr;
246   Fortran::lower::pft::Evaluation *evalPtr = nullptr;
247   Fortran::lower::SymMap localSymbols;
248   Fortran::parser::CharBlock currentPosition;
249 };
250 
251 } // namespace
252 
253 Fortran::evaluate::FoldingContext
254 Fortran::lower::LoweringBridge::createFoldingContext() const {
255   return {getDefaultKinds(), getIntrinsicTable()};
256 }
257 
258 void Fortran::lower::LoweringBridge::lower(
259     const Fortran::parser::Program &prg,
260     const Fortran::semantics::SemanticsContext &semanticsContext) {
261   std::unique_ptr<Fortran::lower::pft::Program> pft =
262       Fortran::lower::createPFT(prg, semanticsContext);
263   if (dumpBeforeFir)
264     Fortran::lower::dumpPFT(llvm::errs(), *pft);
265   FirConverter converter{*this};
266   converter.run(*pft);
267 }
268 
269 Fortran::lower::LoweringBridge::LoweringBridge(
270     mlir::MLIRContext &context,
271     const Fortran::common::IntrinsicTypeDefaultKinds &defaultKinds,
272     const Fortran::evaluate::IntrinsicProcTable &intrinsics,
273     const Fortran::parser::AllCookedSources &cooked, llvm::StringRef triple,
274     fir::KindMapping &kindMap)
275     : defaultKinds{defaultKinds}, intrinsics{intrinsics}, cooked{&cooked},
276       context{context}, kindMap{kindMap} {
277   // Register the diagnostic handler.
278   context.getDiagEngine().registerHandler([](mlir::Diagnostic &diag) {
279     llvm::raw_ostream &os = llvm::errs();
280     switch (diag.getSeverity()) {
281     case mlir::DiagnosticSeverity::Error:
282       os << "error: ";
283       break;
284     case mlir::DiagnosticSeverity::Remark:
285       os << "info: ";
286       break;
287     case mlir::DiagnosticSeverity::Warning:
288       os << "warning: ";
289       break;
290     default:
291       break;
292     }
293     if (!diag.getLocation().isa<UnknownLoc>())
294       os << diag.getLocation() << ": ";
295     os << diag << '\n';
296     os.flush();
297     return mlir::success();
298   });
299 
300   // Create the module and attach the attributes.
301   module = std::make_unique<mlir::ModuleOp>(
302       mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)));
303   assert(module.get() && "module was not created");
304   fir::setTargetTriple(*module.get(), triple);
305   fir::setKindMapping(*module.get(), kindMap);
306 }
307