1 //===-- OpenMP.cpp -- Open MP directive lowering --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Lower/OpenMP.h"
14 #include "flang/Common/idioms.h"
15 #include "flang/Lower/Bridge.h"
16 #include "flang/Lower/PFTBuilder.h"
17 #include "flang/Lower/StatementContext.h"
18 #include "flang/Lower/Todo.h"
19 #include "flang/Optimizer/Builder/BoxValue.h"
20 #include "flang/Optimizer/Builder/FIRBuilder.h"
21 #include "flang/Parser/parse-tree.h"
22 #include "flang/Semantics/tools.h"
23 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
24 #include "llvm/Frontend/OpenMP/OMPConstants.h"
25 
26 using namespace mlir;
27 
28 static const Fortran::parser::Name *
29 getDesignatorNameIfDataRef(const Fortran::parser::Designator &designator) {
30   const auto *dataRef = std::get_if<Fortran::parser::DataRef>(&designator.u);
31   return dataRef ? std::get_if<Fortran::parser::Name>(&dataRef->u) : nullptr;
32 }
33 
34 static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
35                           Fortran::lower::AbstractConverter &converter,
36                           SmallVectorImpl<Value> &operands) {
37   for (const auto &ompObject : objectList.v) {
38     std::visit(
39         Fortran::common::visitors{
40             [&](const Fortran::parser::Designator &designator) {
41               if (const auto *name = getDesignatorNameIfDataRef(designator)) {
42                 const auto variable = converter.getSymbolAddress(*name->symbol);
43                 operands.push_back(variable);
44               }
45             },
46             [&](const Fortran::parser::Name &name) {
47               const auto variable = converter.getSymbolAddress(*name.symbol);
48               operands.push_back(variable);
49             }},
50         ompObject.u);
51   }
52 }
53 
54 template <typename Op>
55 static void createBodyOfOp(Op &op, fir::FirOpBuilder &firOpBuilder,
56                            mlir::Location &loc) {
57   firOpBuilder.createBlock(&op.getRegion());
58   auto &block = op.getRegion().back();
59   firOpBuilder.setInsertionPointToStart(&block);
60   // Ensure the block is well-formed.
61   firOpBuilder.create<mlir::omp::TerminatorOp>(loc);
62   // Reset the insertion point to the start of the first block.
63   firOpBuilder.setInsertionPointToStart(&block);
64 }
65 
66 static void genOMP(Fortran::lower::AbstractConverter &converter,
67                    Fortran::lower::pft::Evaluation &eval,
68                    const Fortran::parser::OpenMPSimpleStandaloneConstruct
69                        &simpleStandaloneConstruct) {
70   const auto &directive =
71       std::get<Fortran::parser::OmpSimpleStandaloneDirective>(
72           simpleStandaloneConstruct.t);
73   switch (directive.v) {
74   default:
75     break;
76   case llvm::omp::Directive::OMPD_barrier:
77     converter.getFirOpBuilder().create<mlir::omp::BarrierOp>(
78         converter.getCurrentLocation());
79     break;
80   case llvm::omp::Directive::OMPD_taskwait:
81     converter.getFirOpBuilder().create<mlir::omp::TaskwaitOp>(
82         converter.getCurrentLocation());
83     break;
84   case llvm::omp::Directive::OMPD_taskyield:
85     converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>(
86         converter.getCurrentLocation());
87     break;
88   case llvm::omp::Directive::OMPD_target_enter_data:
89     TODO(converter.getCurrentLocation(), "OMPD_target_enter_data");
90   case llvm::omp::Directive::OMPD_target_exit_data:
91     TODO(converter.getCurrentLocation(), "OMPD_target_exit_data");
92   case llvm::omp::Directive::OMPD_target_update:
93     TODO(converter.getCurrentLocation(), "OMPD_target_update");
94   case llvm::omp::Directive::OMPD_ordered:
95     TODO(converter.getCurrentLocation(), "OMPD_ordered");
96   }
97 }
98 
99 static void
100 genOMP(Fortran::lower::AbstractConverter &converter,
101        Fortran::lower::pft::Evaluation &eval,
102        const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) {
103   std::visit(
104       Fortran::common::visitors{
105           [&](const Fortran::parser::OpenMPSimpleStandaloneConstruct
106                   &simpleStandaloneConstruct) {
107             genOMP(converter, eval, simpleStandaloneConstruct);
108           },
109           [&](const Fortran::parser::OpenMPFlushConstruct &flushConstruct) {
110             SmallVector<Value, 4> operandRange;
111             if (const auto &ompObjectList =
112                     std::get<std::optional<Fortran::parser::OmpObjectList>>(
113                         flushConstruct.t))
114               genObjectList(*ompObjectList, converter, operandRange);
115             const auto &memOrderClause = std::get<std::optional<
116                 std::list<Fortran::parser::OmpMemoryOrderClause>>>(
117                 flushConstruct.t);
118             if (memOrderClause.has_value() && memOrderClause->size() > 0)
119               TODO(converter.getCurrentLocation(),
120                    "Handle OmpMemoryOrderClause");
121             converter.getFirOpBuilder().create<mlir::omp::FlushOp>(
122                 converter.getCurrentLocation(), operandRange);
123           },
124           [&](const Fortran::parser::OpenMPCancelConstruct &cancelConstruct) {
125             TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct");
126           },
127           [&](const Fortran::parser::OpenMPCancellationPointConstruct
128                   &cancellationPointConstruct) {
129             TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct");
130           },
131       },
132       standaloneConstruct.u);
133 }
134 
135 static void
136 genOMP(Fortran::lower::AbstractConverter &converter,
137        Fortran::lower::pft::Evaluation &eval,
138        const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
139   const auto &beginBlockDirective =
140       std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t);
141   const auto &blockDirective =
142       std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t);
143 
144   auto &firOpBuilder = converter.getFirOpBuilder();
145   auto currentLocation = converter.getCurrentLocation();
146   Fortran::lower::StatementContext stmtCtx;
147   llvm::ArrayRef<mlir::Type> argTy;
148   if (blockDirective.v == llvm::omp::OMPD_parallel) {
149 
150     mlir::Value ifClauseOperand, numThreadsClauseOperand;
151     Attribute procBindClauseOperand;
152 
153     const auto &parallelOpClauseList =
154         std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
155     for (const auto &clause : parallelOpClauseList.v) {
156       if (const auto &ifClause =
157               std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
158         auto &expr =
159             std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
160         ifClauseOperand = fir::getBase(converter.genExprValue(
161             *Fortran::semantics::GetExpr(expr), stmtCtx));
162       } else if (const auto &numThreadsClause =
163                      std::get_if<Fortran::parser::OmpClause::NumThreads>(
164                          &clause.u)) {
165         // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`.
166         numThreadsClauseOperand = fir::getBase(converter.genExprValue(
167             *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx));
168       }
169       // TODO: Handle private, firstprivate, shared and copyin
170     }
171     // Create and insert the operation.
172     auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
173         currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
174         ValueRange(), ValueRange(),
175         procBindClauseOperand.dyn_cast_or_null<omp::ClauseProcBindKindAttr>());
176     // Handle attribute based clauses.
177     for (const auto &clause : parallelOpClauseList.v) {
178       // TODO: Handle default clause
179       if (const auto &procBindClause =
180               std::get_if<Fortran::parser::OmpClause::ProcBind>(&clause.u)) {
181         const auto &ompProcBindClause{procBindClause->v};
182         omp::ClauseProcBindKind pbKind;
183         switch (ompProcBindClause.v) {
184         case Fortran::parser::OmpProcBindClause::Type::Master:
185           pbKind = omp::ClauseProcBindKind::Master;
186           break;
187         case Fortran::parser::OmpProcBindClause::Type::Close:
188           pbKind = omp::ClauseProcBindKind::Close;
189           break;
190         case Fortran::parser::OmpProcBindClause::Type::Spread:
191           pbKind = omp::ClauseProcBindKind::Spread;
192           break;
193         }
194         parallelOp.proc_bind_valAttr(omp::ClauseProcBindKindAttr::get(
195             firOpBuilder.getContext(), pbKind));
196       }
197     }
198     createBodyOfOp<omp::ParallelOp>(parallelOp, firOpBuilder, currentLocation);
199   } else if (blockDirective.v == llvm::omp::OMPD_master) {
200     auto masterOp =
201         firOpBuilder.create<mlir::omp::MasterOp>(currentLocation, argTy);
202     createBodyOfOp<omp::MasterOp>(masterOp, firOpBuilder, currentLocation);
203   }
204 }
205 
206 static void
207 genOMP(Fortran::lower::AbstractConverter &converter,
208        Fortran::lower::pft::Evaluation &eval,
209        const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) {
210   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
211   mlir::Location currentLocation = converter.getCurrentLocation();
212   std::string name;
213   const Fortran::parser::OmpCriticalDirective &cd =
214       std::get<Fortran::parser::OmpCriticalDirective>(criticalConstruct.t);
215   if (std::get<std::optional<Fortran::parser::Name>>(cd.t).has_value()) {
216     name =
217         std::get<std::optional<Fortran::parser::Name>>(cd.t).value().ToString();
218   }
219 
220   uint64_t hint = 0;
221   const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t);
222   for (const Fortran::parser::OmpClause &clause : clauseList.v)
223     if (auto hintClause =
224             std::get_if<Fortran::parser::OmpClause::Hint>(&clause.u)) {
225       const auto *expr = Fortran::semantics::GetExpr(hintClause->v);
226       hint = *Fortran::evaluate::ToInt64(*expr);
227       break;
228     }
229 
230   mlir::omp::CriticalOp criticalOp = [&]() {
231     if (name.empty()) {
232       return firOpBuilder.create<mlir::omp::CriticalOp>(currentLocation,
233                                                         FlatSymbolRefAttr());
234     } else {
235       mlir::ModuleOp module = firOpBuilder.getModule();
236       mlir::OpBuilder modBuilder(module.getBodyRegion());
237       auto global = module.lookupSymbol<mlir::omp::CriticalDeclareOp>(name);
238       if (!global)
239         global = modBuilder.create<mlir::omp::CriticalDeclareOp>(
240             currentLocation, name, hint);
241       return firOpBuilder.create<mlir::omp::CriticalOp>(
242           currentLocation, mlir::FlatSymbolRefAttr::get(
243                                firOpBuilder.getContext(), global.sym_name()));
244     }
245   }();
246   createBodyOfOp<omp::CriticalOp>(criticalOp, firOpBuilder, currentLocation);
247 }
248 
249 void Fortran::lower::genOpenMPConstruct(
250     Fortran::lower::AbstractConverter &converter,
251     Fortran::lower::pft::Evaluation &eval,
252     const Fortran::parser::OpenMPConstruct &ompConstruct) {
253 
254   std::visit(
255       common::visitors{
256           [&](const Fortran::parser::OpenMPStandaloneConstruct
257                   &standaloneConstruct) {
258             genOMP(converter, eval, standaloneConstruct);
259           },
260           [&](const Fortran::parser::OpenMPSectionsConstruct
261                   &sectionsConstruct) {
262             TODO(converter.getCurrentLocation(), "OpenMPSectionsConstruct");
263           },
264           [&](const Fortran::parser::OpenMPSectionConstruct &sectionConstruct) {
265             TODO(converter.getCurrentLocation(), "OpenMPSectionConstruct");
266           },
267           [&](const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
268             TODO(converter.getCurrentLocation(), "OpenMPLoopConstruct");
269           },
270           [&](const Fortran::parser::OpenMPDeclarativeAllocate
271                   &execAllocConstruct) {
272             TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate");
273           },
274           [&](const Fortran::parser::OpenMPExecutableAllocate
275                   &execAllocConstruct) {
276             TODO(converter.getCurrentLocation(), "OpenMPExecutableAllocate");
277           },
278           [&](const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
279             genOMP(converter, eval, blockConstruct);
280           },
281           [&](const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) {
282             TODO(converter.getCurrentLocation(), "OpenMPAtomicConstruct");
283           },
284           [&](const Fortran::parser::OpenMPCriticalConstruct
285                   &criticalConstruct) {
286             genOMP(converter, eval, criticalConstruct);
287           },
288       },
289       ompConstruct.u);
290 }
291