1 //===-- OpenMP.cpp -- Open MP directive lowering --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Lower/OpenMP.h"
14 #include "flang/Common/idioms.h"
15 #include "flang/Lower/Bridge.h"
16 #include "flang/Lower/PFTBuilder.h"
17 #include "flang/Lower/StatementContext.h"
18 #include "flang/Lower/Todo.h"
19 #include "flang/Optimizer/Builder/BoxValue.h"
20 #include "flang/Optimizer/Builder/FIRBuilder.h"
21 #include "flang/Parser/parse-tree.h"
22 #include "flang/Semantics/tools.h"
23 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
24 #include "llvm/Frontend/OpenMP/OMPConstants.h"
25 
26 using namespace mlir;
27 
28 static const Fortran::parser::Name *
29 getDesignatorNameIfDataRef(const Fortran::parser::Designator &designator) {
30   const auto *dataRef = std::get_if<Fortran::parser::DataRef>(&designator.u);
31   return dataRef ? std::get_if<Fortran::parser::Name>(&dataRef->u) : nullptr;
32 }
33 
34 static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
35                           Fortran::lower::AbstractConverter &converter,
36                           SmallVectorImpl<Value> &operands) {
37   for (const auto &ompObject : objectList.v) {
38     std::visit(
39         Fortran::common::visitors{
40             [&](const Fortran::parser::Designator &designator) {
41               if (const auto *name = getDesignatorNameIfDataRef(designator)) {
42                 const auto variable = converter.getSymbolAddress(*name->symbol);
43                 operands.push_back(variable);
44               }
45             },
46             [&](const Fortran::parser::Name &name) {
47               const auto variable = converter.getSymbolAddress(*name.symbol);
48               operands.push_back(variable);
49             }},
50         ompObject.u);
51   }
52 }
53 
54 template <typename Op>
55 static void createBodyOfOp(Op &op, fir::FirOpBuilder &firOpBuilder,
56                            mlir::Location &loc) {
57   firOpBuilder.createBlock(&op.getRegion());
58   auto &block = op.getRegion().back();
59   firOpBuilder.setInsertionPointToStart(&block);
60   // Ensure the block is well-formed.
61   firOpBuilder.create<mlir::omp::TerminatorOp>(loc);
62   // Reset the insertion point to the start of the first block.
63   firOpBuilder.setInsertionPointToStart(&block);
64 }
65 
66 static void genOMP(Fortran::lower::AbstractConverter &converter,
67                    Fortran::lower::pft::Evaluation &eval,
68                    const Fortran::parser::OpenMPSimpleStandaloneConstruct
69                        &simpleStandaloneConstruct) {
70   const auto &directive =
71       std::get<Fortran::parser::OmpSimpleStandaloneDirective>(
72           simpleStandaloneConstruct.t);
73   switch (directive.v) {
74   default:
75     break;
76   case llvm::omp::Directive::OMPD_barrier:
77     converter.getFirOpBuilder().create<mlir::omp::BarrierOp>(
78         converter.getCurrentLocation());
79     break;
80   case llvm::omp::Directive::OMPD_taskwait:
81     converter.getFirOpBuilder().create<mlir::omp::TaskwaitOp>(
82         converter.getCurrentLocation());
83     break;
84   case llvm::omp::Directive::OMPD_taskyield:
85     converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>(
86         converter.getCurrentLocation());
87     break;
88   case llvm::omp::Directive::OMPD_target_enter_data:
89     TODO(converter.getCurrentLocation(), "OMPD_target_enter_data");
90   case llvm::omp::Directive::OMPD_target_exit_data:
91     TODO(converter.getCurrentLocation(), "OMPD_target_exit_data");
92   case llvm::omp::Directive::OMPD_target_update:
93     TODO(converter.getCurrentLocation(), "OMPD_target_update");
94   case llvm::omp::Directive::OMPD_ordered:
95     TODO(converter.getCurrentLocation(), "OMPD_ordered");
96   }
97 }
98 
99 static void
100 genAllocateClause(Fortran::lower::AbstractConverter &converter,
101                   const Fortran::parser::OmpAllocateClause &ompAllocateClause,
102                   SmallVector<Value> &allocatorOperands,
103                   SmallVector<Value> &allocateOperands) {
104   auto &firOpBuilder = converter.getFirOpBuilder();
105   auto currentLocation = converter.getCurrentLocation();
106   Fortran::lower::StatementContext stmtCtx;
107 
108   mlir::Value allocatorOperand;
109   const Fortran::parser::OmpObjectList &ompObjectList =
110       std::get<Fortran::parser::OmpObjectList>(ompAllocateClause.t);
111   const auto &allocatorValue =
112       std::get<std::optional<Fortran::parser::OmpAllocateClause::Allocator>>(
113           ompAllocateClause.t);
114   // Check if allocate clause has allocator specified. If so, add it
115   // to list of allocators, otherwise, add default allocator to
116   // list of allocators.
117   if (allocatorValue) {
118     allocatorOperand = fir::getBase(converter.genExprValue(
119         *Fortran::semantics::GetExpr(allocatorValue->v), stmtCtx));
120     allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(),
121                              allocatorOperand);
122   } else {
123     allocatorOperand = firOpBuilder.createIntegerConstant(
124         currentLocation, firOpBuilder.getI32Type(), 1);
125     allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(),
126                              allocatorOperand);
127   }
128   genObjectList(ompObjectList, converter, allocateOperands);
129 }
130 
131 static void
132 genOMP(Fortran::lower::AbstractConverter &converter,
133        Fortran::lower::pft::Evaluation &eval,
134        const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) {
135   std::visit(
136       Fortran::common::visitors{
137           [&](const Fortran::parser::OpenMPSimpleStandaloneConstruct
138                   &simpleStandaloneConstruct) {
139             genOMP(converter, eval, simpleStandaloneConstruct);
140           },
141           [&](const Fortran::parser::OpenMPFlushConstruct &flushConstruct) {
142             SmallVector<Value, 4> operandRange;
143             if (const auto &ompObjectList =
144                     std::get<std::optional<Fortran::parser::OmpObjectList>>(
145                         flushConstruct.t))
146               genObjectList(*ompObjectList, converter, operandRange);
147             const auto &memOrderClause = std::get<std::optional<
148                 std::list<Fortran::parser::OmpMemoryOrderClause>>>(
149                 flushConstruct.t);
150             if (memOrderClause.has_value() && memOrderClause->size() > 0)
151               TODO(converter.getCurrentLocation(),
152                    "Handle OmpMemoryOrderClause");
153             converter.getFirOpBuilder().create<mlir::omp::FlushOp>(
154                 converter.getCurrentLocation(), operandRange);
155           },
156           [&](const Fortran::parser::OpenMPCancelConstruct &cancelConstruct) {
157             TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct");
158           },
159           [&](const Fortran::parser::OpenMPCancellationPointConstruct
160                   &cancellationPointConstruct) {
161             TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct");
162           },
163       },
164       standaloneConstruct.u);
165 }
166 
167 static void
168 genOMP(Fortran::lower::AbstractConverter &converter,
169        Fortran::lower::pft::Evaluation &eval,
170        const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
171   const auto &beginBlockDirective =
172       std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t);
173   const auto &blockDirective =
174       std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t);
175   const auto &endBlockDirective =
176       std::get<Fortran::parser::OmpEndBlockDirective>(blockConstruct.t);
177 
178   auto &firOpBuilder = converter.getFirOpBuilder();
179   auto currentLocation = converter.getCurrentLocation();
180   Fortran::lower::StatementContext stmtCtx;
181   llvm::ArrayRef<mlir::Type> argTy;
182   mlir::Value ifClauseOperand, numThreadsClauseOperand;
183   mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
184   SmallVector<Value> allocateOperands, allocatorOperands;
185   mlir::UnitAttr nowaitAttr;
186 
187   for (const auto &clause :
188        std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t).v) {
189     if (const auto &ifClause =
190             std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
191       auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
192       ifClauseOperand = fir::getBase(
193           converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
194     } else if (const auto &numThreadsClause =
195                    std::get_if<Fortran::parser::OmpClause::NumThreads>(
196                        &clause.u)) {
197       // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`.
198       numThreadsClauseOperand = fir::getBase(converter.genExprValue(
199           *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx));
200     } else if (const auto &procBindClause =
201                    std::get_if<Fortran::parser::OmpClause::ProcBind>(
202                        &clause.u)) {
203       omp::ClauseProcBindKind pbKind;
204       switch (procBindClause->v.v) {
205       case Fortran::parser::OmpProcBindClause::Type::Master:
206         pbKind = omp::ClauseProcBindKind::Master;
207         break;
208       case Fortran::parser::OmpProcBindClause::Type::Close:
209         pbKind = omp::ClauseProcBindKind::Close;
210         break;
211       case Fortran::parser::OmpProcBindClause::Type::Spread:
212         pbKind = omp::ClauseProcBindKind::Spread;
213         break;
214       case Fortran::parser::OmpProcBindClause::Type::Primary:
215         pbKind = omp::ClauseProcBindKind::Primary;
216         break;
217       }
218       procBindKindAttr =
219           omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(), pbKind);
220     } else if (const auto &allocateClause =
221                    std::get_if<Fortran::parser::OmpClause::Allocate>(
222                        &clause.u)) {
223       genAllocateClause(converter, allocateClause->v, allocatorOperands,
224                         allocateOperands);
225     } else if (const auto &privateClause =
226                    std::get_if<Fortran::parser::OmpClause::Private>(
227                        &clause.u)) {
228       // TODO: Handle private. This cannot be a hard TODO because testing for
229       // allocate clause requires private variables.
230     } else {
231       TODO(currentLocation, "OpenMP Block construct clauses");
232     }
233   }
234 
235   for (const auto &clause :
236        std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t).v) {
237     if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u))
238       nowaitAttr = firOpBuilder.getUnitAttr();
239   }
240 
241   if (blockDirective.v == llvm::omp::OMPD_parallel) {
242     // Create and insert the operation.
243     auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
244         currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
245         allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(),
246         /*reductions=*/nullptr, procBindKindAttr);
247     createBodyOfOp<omp::ParallelOp>(parallelOp, firOpBuilder, currentLocation);
248   } else if (blockDirective.v == llvm::omp::OMPD_master) {
249     auto masterOp =
250         firOpBuilder.create<mlir::omp::MasterOp>(currentLocation, argTy);
251     createBodyOfOp<omp::MasterOp>(masterOp, firOpBuilder, currentLocation);
252   } else if (blockDirective.v == llvm::omp::OMPD_single) {
253     auto singleOp = firOpBuilder.create<mlir::omp::SingleOp>(
254         currentLocation, allocateOperands, allocatorOperands, nowaitAttr);
255     createBodyOfOp(singleOp, firOpBuilder, currentLocation);
256   }
257 }
258 
259 static void
260 genOMP(Fortran::lower::AbstractConverter &converter,
261        Fortran::lower::pft::Evaluation &eval,
262        const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) {
263   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
264   mlir::Location currentLocation = converter.getCurrentLocation();
265   std::string name;
266   const Fortran::parser::OmpCriticalDirective &cd =
267       std::get<Fortran::parser::OmpCriticalDirective>(criticalConstruct.t);
268   if (std::get<std::optional<Fortran::parser::Name>>(cd.t).has_value()) {
269     name =
270         std::get<std::optional<Fortran::parser::Name>>(cd.t).value().ToString();
271   }
272 
273   uint64_t hint = 0;
274   const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t);
275   for (const Fortran::parser::OmpClause &clause : clauseList.v)
276     if (auto hintClause =
277             std::get_if<Fortran::parser::OmpClause::Hint>(&clause.u)) {
278       const auto *expr = Fortran::semantics::GetExpr(hintClause->v);
279       hint = *Fortran::evaluate::ToInt64(*expr);
280       break;
281     }
282 
283   mlir::omp::CriticalOp criticalOp = [&]() {
284     if (name.empty()) {
285       return firOpBuilder.create<mlir::omp::CriticalOp>(currentLocation,
286                                                         FlatSymbolRefAttr());
287     } else {
288       mlir::ModuleOp module = firOpBuilder.getModule();
289       mlir::OpBuilder modBuilder(module.getBodyRegion());
290       auto global = module.lookupSymbol<mlir::omp::CriticalDeclareOp>(name);
291       if (!global)
292         global = modBuilder.create<mlir::omp::CriticalDeclareOp>(
293             currentLocation, name, hint);
294       return firOpBuilder.create<mlir::omp::CriticalOp>(
295           currentLocation, mlir::FlatSymbolRefAttr::get(
296                                firOpBuilder.getContext(), global.sym_name()));
297     }
298   }();
299   createBodyOfOp<omp::CriticalOp>(criticalOp, firOpBuilder, currentLocation);
300 }
301 
302 static void
303 genOMP(Fortran::lower::AbstractConverter &converter,
304        Fortran::lower::pft::Evaluation &eval,
305        const Fortran::parser::OpenMPSectionConstruct &sectionConstruct) {
306 
307   auto &firOpBuilder = converter.getFirOpBuilder();
308   auto currentLocation = converter.getCurrentLocation();
309   mlir::omp::SectionOp sectionOp =
310       firOpBuilder.create<mlir::omp::SectionOp>(currentLocation);
311   createBodyOfOp<omp::SectionOp>(sectionOp, firOpBuilder, currentLocation);
312 }
313 
314 // TODO: Add support for reduction
315 static void
316 genOMP(Fortran::lower::AbstractConverter &converter,
317        Fortran::lower::pft::Evaluation &eval,
318        const Fortran::parser::OpenMPSectionsConstruct &sectionsConstruct) {
319   auto &firOpBuilder = converter.getFirOpBuilder();
320   auto currentLocation = converter.getCurrentLocation();
321   SmallVector<Value> reductionVars, allocateOperands, allocatorOperands;
322   mlir::UnitAttr noWaitClauseOperand;
323   const auto &sectionsClauseList = std::get<Fortran::parser::OmpClauseList>(
324       std::get<Fortran::parser::OmpBeginSectionsDirective>(sectionsConstruct.t)
325           .t);
326   for (const Fortran::parser::OmpClause &clause : sectionsClauseList.v) {
327 
328     // Reduction Clause
329     if (std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) {
330       TODO(currentLocation, "OMPC_Reduction");
331 
332       // Allocate clause
333     } else if (const auto &allocateClause =
334                    std::get_if<Fortran::parser::OmpClause::Allocate>(
335                        &clause.u)) {
336       genAllocateClause(converter, allocateClause->v, allocatorOperands,
337                         allocateOperands);
338     }
339   }
340   const auto &endSectionsClauseList =
341       std::get<Fortran::parser::OmpEndSectionsDirective>(sectionsConstruct.t);
342   const auto &clauseList =
343       std::get<Fortran::parser::OmpClauseList>(endSectionsClauseList.t);
344   for (const auto &clause : clauseList.v) {
345     // Nowait clause
346     if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u)) {
347       noWaitClauseOperand = firOpBuilder.getUnitAttr();
348     }
349   }
350 
351   llvm::omp::Directive dir =
352       std::get<Fortran::parser::OmpSectionsDirective>(
353           std::get<Fortran::parser::OmpBeginSectionsDirective>(
354               sectionsConstruct.t)
355               .t)
356           .v;
357 
358   // Parallel Sections Construct
359   if (dir == llvm::omp::Directive::OMPD_parallel_sections) {
360     auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
361         currentLocation, /*if_expr_var*/ nullptr, /*num_threads_var*/ nullptr,
362         allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(),
363         /*reductions=*/nullptr, /*proc_bind_val*/ nullptr);
364     createBodyOfOp(parallelOp, firOpBuilder, currentLocation);
365     auto sectionsOp = firOpBuilder.create<mlir::omp::SectionsOp>(
366         currentLocation, /*reduction_vars*/ ValueRange(),
367         /*reductions=*/nullptr, /*allocate_vars*/ ValueRange(),
368         /*allocators_vars*/ ValueRange(), /*nowait=*/nullptr);
369     createBodyOfOp(sectionsOp, firOpBuilder, currentLocation);
370 
371     // Sections Construct
372   } else if (dir == llvm::omp::Directive::OMPD_sections) {
373     auto sectionsOp = firOpBuilder.create<mlir::omp::SectionsOp>(
374         currentLocation, reductionVars, /*reductions = */ nullptr,
375         allocateOperands, allocatorOperands, noWaitClauseOperand);
376     createBodyOfOp<omp::SectionsOp>(sectionsOp, firOpBuilder, currentLocation);
377   }
378 }
379 
380 void Fortran::lower::genOpenMPConstruct(
381     Fortran::lower::AbstractConverter &converter,
382     Fortran::lower::pft::Evaluation &eval,
383     const Fortran::parser::OpenMPConstruct &ompConstruct) {
384 
385   std::visit(
386       common::visitors{
387           [&](const Fortran::parser::OpenMPStandaloneConstruct
388                   &standaloneConstruct) {
389             genOMP(converter, eval, standaloneConstruct);
390           },
391           [&](const Fortran::parser::OpenMPSectionsConstruct
392                   &sectionsConstruct) {
393             genOMP(converter, eval, sectionsConstruct);
394           },
395           [&](const Fortran::parser::OpenMPSectionConstruct &sectionConstruct) {
396             genOMP(converter, eval, sectionConstruct);
397           },
398           [&](const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
399             TODO(converter.getCurrentLocation(), "OpenMPLoopConstruct");
400           },
401           [&](const Fortran::parser::OpenMPDeclarativeAllocate
402                   &execAllocConstruct) {
403             TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate");
404           },
405           [&](const Fortran::parser::OpenMPExecutableAllocate
406                   &execAllocConstruct) {
407             TODO(converter.getCurrentLocation(), "OpenMPExecutableAllocate");
408           },
409           [&](const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
410             genOMP(converter, eval, blockConstruct);
411           },
412           [&](const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) {
413             TODO(converter.getCurrentLocation(), "OpenMPAtomicConstruct");
414           },
415           [&](const Fortran::parser::OpenMPCriticalConstruct
416                   &criticalConstruct) {
417             genOMP(converter, eval, criticalConstruct);
418           },
419       },
420       ompConstruct.u);
421 }
422