1 //===-- OpenMP.cpp -- Open MP directive lowering --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Lower/OpenMP.h"
14 #include "flang/Common/idioms.h"
15 #include "flang/Lower/Bridge.h"
16 #include "flang/Lower/ConvertExpr.h"
17 #include "flang/Lower/PFTBuilder.h"
18 #include "flang/Lower/StatementContext.h"
19 #include "flang/Optimizer/Builder/BoxValue.h"
20 #include "flang/Optimizer/Builder/FIRBuilder.h"
21 #include "flang/Optimizer/Builder/Todo.h"
22 #include "flang/Parser/parse-tree.h"
23 #include "flang/Semantics/tools.h"
24 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
25 #include "mlir/Dialect/SCF/IR/SCF.h"
26 #include "llvm/Frontend/OpenMP/OMPConstants.h"
27 
28 using namespace mlir;
29 
getCollapseValue(const Fortran::parser::OmpClauseList & clauseList)30 int64_t Fortran::lower::getCollapseValue(
31     const Fortran::parser::OmpClauseList &clauseList) {
32   for (const auto &clause : clauseList.v) {
33     if (const auto &collapseClause =
34             std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u)) {
35       const auto *expr = Fortran::semantics::GetExpr(collapseClause->v);
36       return Fortran::evaluate::ToInt64(*expr).value();
37     }
38   }
39   return 1;
40 }
41 
42 static const Fortran::parser::Name *
getDesignatorNameIfDataRef(const Fortran::parser::Designator & designator)43 getDesignatorNameIfDataRef(const Fortran::parser::Designator &designator) {
44   const auto *dataRef = std::get_if<Fortran::parser::DataRef>(&designator.u);
45   return dataRef ? std::get_if<Fortran::parser::Name>(&dataRef->u) : nullptr;
46 }
47 
48 static Fortran::semantics::Symbol *
getOmpObjectSymbol(const Fortran::parser::OmpObject & ompObject)49 getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) {
50   Fortran::semantics::Symbol *sym = nullptr;
51   std::visit(Fortran::common::visitors{
52                  [&](const Fortran::parser::Designator &designator) {
53                    if (const Fortran::parser::Name *name =
54                            getDesignatorNameIfDataRef(designator)) {
55                      sym = name->symbol;
56                    }
57                  },
58                  [&](const Fortran::parser::Name &name) { sym = name.symbol; }},
59              ompObject.u);
60   return sym;
61 }
62 
63 template <typename T>
createPrivateVarSyms(Fortran::lower::AbstractConverter & converter,const T * clause,Block * lastPrivBlock=nullptr)64 static void createPrivateVarSyms(Fortran::lower::AbstractConverter &converter,
65                                  const T *clause,
66                                  Block *lastPrivBlock = nullptr) {
67   const Fortran::parser::OmpObjectList &ompObjectList = clause->v;
68   for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) {
69     Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
70     // Privatization for symbols which are pre-determined (like loop index
71     // variables) happen separately, for everything else privatize here.
72     if (sym->test(Fortran::semantics::Symbol::Flag::OmpPreDetermined))
73       continue;
74     bool success = converter.createHostAssociateVarClone(*sym);
75     (void)success;
76     assert(success && "Privatization failed due to existing binding");
77     if constexpr (std::is_same_v<T, Fortran::parser::OmpClause::Firstprivate>) {
78       converter.copyHostAssociateVar(*sym);
79     } else if constexpr (std::is_same_v<
80                              T, Fortran::parser::OmpClause::Lastprivate>) {
81       converter.copyHostAssociateVar(*sym, lastPrivBlock);
82     }
83   }
84 }
85 
86 template <typename Op>
privatizeVars(Op & op,Fortran::lower::AbstractConverter & converter,const Fortran::parser::OmpClauseList & opClauseList)87 static bool privatizeVars(Op &op, Fortran::lower::AbstractConverter &converter,
88                           const Fortran::parser::OmpClauseList &opClauseList) {
89   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
90   auto insPt = firOpBuilder.saveInsertionPoint();
91   firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
92   bool hasFirstPrivateOp = false;
93   bool hasLastPrivateOp = false;
94   // We need just one ICmpOp for multiple LastPrivate clauses.
95   mlir::arith::CmpIOp cmpOp;
96 
97   for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
98     if (const auto &privateClause =
99             std::get_if<Fortran::parser::OmpClause::Private>(&clause.u)) {
100       createPrivateVarSyms(converter, privateClause);
101     } else if (const auto &firstPrivateClause =
102                    std::get_if<Fortran::parser::OmpClause::Firstprivate>(
103                        &clause.u)) {
104       createPrivateVarSyms(converter, firstPrivateClause);
105       hasFirstPrivateOp = true;
106     } else if (const auto &lastPrivateClause =
107                    std::get_if<Fortran::parser::OmpClause::Lastprivate>(
108                        &clause.u)) {
109       // TODO: Add lastprivate support for sections construct, simd construct
110       if (std::is_same_v<Op, omp::WsLoopOp>) {
111         omp::WsLoopOp *wsLoopOp = dyn_cast<omp::WsLoopOp>(&op);
112         fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
113         auto insPt = firOpBuilder.saveInsertionPoint();
114 
115         // Our goal here is to introduce the following control flow
116         // just before exiting the worksharing loop.
117         // Say our wsloop is as follows:
118         //
119         // omp.wsloop {
120         //    ...
121         //    store
122         //    omp.yield
123         // }
124         //
125         // We want to convert it to the following:
126         //
127         // omp.wsloop {
128         //    ...
129         //    store
130         //    %cmp = llvm.icmp "eq" %iv %ub
131         //    scf.if %cmp {
132         //      ^%lpv_update_blk:
133         //    }
134         //    omp.yield
135         // }
136 
137         Operation *lastOper = wsLoopOp->region().back().getTerminator();
138 
139         firOpBuilder.setInsertionPoint(lastOper);
140 
141         // TODO: The following will not work when there is collapse present.
142         // Have to modify this in future.
143         for (const Fortran::parser::OmpClause &clause : opClauseList.v)
144           if (const auto &collapseClause =
145                   std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u))
146             TODO(converter.getCurrentLocation(),
147                  "Collapse clause with lastprivate");
148         // Only generate the compare once in presence of multiple LastPrivate
149         // clauses
150         if (!hasLastPrivateOp) {
151           cmpOp = firOpBuilder.create<mlir::arith::CmpIOp>(
152               wsLoopOp->getLoc(), mlir::arith::CmpIPredicate::eq,
153               wsLoopOp->getRegion().front().getArguments()[0],
154               wsLoopOp->upperBound()[0]);
155         }
156         mlir::scf::IfOp ifOp = firOpBuilder.create<mlir::scf::IfOp>(
157             wsLoopOp->getLoc(), cmpOp, /*else*/ false);
158 
159         firOpBuilder.restoreInsertionPoint(insPt);
160         createPrivateVarSyms(converter, lastPrivateClause,
161                              &(ifOp.getThenRegion().front()));
162       } else {
163         TODO(converter.getCurrentLocation(),
164              "lastprivate clause in constructs other than work-share loop");
165       }
166       hasLastPrivateOp = true;
167     }
168   }
169   if (hasFirstPrivateOp)
170     firOpBuilder.create<mlir::omp::BarrierOp>(converter.getCurrentLocation());
171   firOpBuilder.restoreInsertionPoint(insPt);
172   return hasLastPrivateOp;
173 }
174 
175 /// The COMMON block is a global structure. \p commonValue is the base address
176 /// of the the COMMON block. As the offset from the symbol \p sym, generate the
177 /// COMMON block member value (commonValue + offset) for the symbol.
178 /// FIXME: Share the code with `instantiateCommon` in ConvertVariable.cpp.
179 static mlir::Value
genCommonBlockMember(Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol & sym,mlir::Value commonValue)180 genCommonBlockMember(Fortran::lower::AbstractConverter &converter,
181                      const Fortran::semantics::Symbol &sym,
182                      mlir::Value commonValue) {
183   auto &firOpBuilder = converter.getFirOpBuilder();
184   mlir::Location currentLocation = converter.getCurrentLocation();
185   mlir::IntegerType i8Ty = firOpBuilder.getIntegerType(8);
186   mlir::Type i8Ptr = firOpBuilder.getRefType(i8Ty);
187   mlir::Type seqTy = firOpBuilder.getRefType(firOpBuilder.getVarLenSeqTy(i8Ty));
188   mlir::Value base =
189       firOpBuilder.createConvert(currentLocation, seqTy, commonValue);
190   std::size_t byteOffset = sym.GetUltimate().offset();
191   mlir::Value offs = firOpBuilder.createIntegerConstant(
192       currentLocation, firOpBuilder.getIndexType(), byteOffset);
193   mlir::Value varAddr = firOpBuilder.create<fir::CoordinateOp>(
194       currentLocation, i8Ptr, base, mlir::ValueRange{offs});
195   mlir::Type symType = converter.genType(sym);
196   return firOpBuilder.createConvert(currentLocation,
197                                     firOpBuilder.getRefType(symType), varAddr);
198 }
199 
200 // Get the extended value for \p val by extracting additional variable
201 // information from \p base.
getExtendedValue(fir::ExtendedValue base,mlir::Value val)202 static fir::ExtendedValue getExtendedValue(fir::ExtendedValue base,
203                                            mlir::Value val) {
204   return base.match(
205       [&](const fir::MutableBoxValue &box) -> fir::ExtendedValue {
206         return fir::MutableBoxValue(val, box.nonDeferredLenParams(), {});
207       },
208       [&](const auto &) -> fir::ExtendedValue {
209         return fir::substBase(base, val);
210       });
211 }
212 
threadPrivatizeVars(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval)213 static void threadPrivatizeVars(Fortran::lower::AbstractConverter &converter,
214                                 Fortran::lower::pft::Evaluation &eval) {
215   auto &firOpBuilder = converter.getFirOpBuilder();
216   mlir::Location currentLocation = converter.getCurrentLocation();
217   auto insPt = firOpBuilder.saveInsertionPoint();
218   firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
219 
220   // Get the original ThreadprivateOp corresponding to the symbol and use the
221   // symbol value from that opeartion to create one ThreadprivateOp copy
222   // operation inside the parallel region.
223   auto genThreadprivateOp = [&](Fortran::lower::SymbolRef sym) -> mlir::Value {
224     mlir::Value symOriThreadprivateValue = converter.getSymbolAddress(sym);
225     mlir::Operation *op = symOriThreadprivateValue.getDefiningOp();
226     assert(mlir::isa<mlir::omp::ThreadprivateOp>(op) &&
227            "The threadprivate operation not created");
228     mlir::Value symValue =
229         mlir::dyn_cast<mlir::omp::ThreadprivateOp>(op).sym_addr();
230     return firOpBuilder.create<mlir::omp::ThreadprivateOp>(
231         currentLocation, symValue.getType(), symValue);
232   };
233 
234   llvm::SetVector<const Fortran::semantics::Symbol *> threadprivateSyms;
235   converter.collectSymbolSet(eval, threadprivateSyms,
236                              Fortran::semantics::Symbol::Flag::OmpThreadprivate,
237                              /*isUltimateSymbol=*/false);
238   std::set<Fortran::semantics::SourceName> threadprivateSymNames;
239 
240   // For a COMMON block, the ThreadprivateOp is generated for itself instead of
241   // its members, so only bind the value of the new copied ThreadprivateOp
242   // inside the parallel region to the common block symbol only once for
243   // multiple members in one COMMON block.
244   llvm::SetVector<const Fortran::semantics::Symbol *> commonSyms;
245   for (std::size_t i = 0; i < threadprivateSyms.size(); i++) {
246     auto sym = threadprivateSyms[i];
247     mlir::Value symThreadprivateValue;
248     // The variable may be used more than once, and each reference has one
249     // symbol with the same name. Only do once for references of one variable.
250     if (threadprivateSymNames.find(sym->name()) != threadprivateSymNames.end())
251       continue;
252     threadprivateSymNames.insert(sym->name());
253     if (const Fortran::semantics::Symbol *common =
254             Fortran::semantics::FindCommonBlockContaining(sym->GetUltimate())) {
255       mlir::Value commonThreadprivateValue;
256       if (commonSyms.contains(common)) {
257         commonThreadprivateValue = converter.getSymbolAddress(*common);
258       } else {
259         commonThreadprivateValue = genThreadprivateOp(*common);
260         converter.bindSymbol(*common, commonThreadprivateValue);
261         commonSyms.insert(common);
262       }
263       symThreadprivateValue =
264           genCommonBlockMember(converter, *sym, commonThreadprivateValue);
265     } else {
266       symThreadprivateValue = genThreadprivateOp(*sym);
267     }
268 
269     fir::ExtendedValue sexv = converter.getSymbolExtendedValue(*sym);
270     fir::ExtendedValue symThreadprivateExv =
271         getExtendedValue(sexv, symThreadprivateValue);
272     converter.bindSymbol(*sym, symThreadprivateExv);
273   }
274 
275   firOpBuilder.restoreInsertionPoint(insPt);
276 }
277 
278 static void
genCopyinClause(Fortran::lower::AbstractConverter & converter,const Fortran::parser::OmpClauseList & opClauseList)279 genCopyinClause(Fortran::lower::AbstractConverter &converter,
280                 const Fortran::parser::OmpClauseList &opClauseList) {
281   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
282   mlir::OpBuilder::InsertPoint insPt = firOpBuilder.saveInsertionPoint();
283   firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
284   bool hasCopyin = false;
285   for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
286     if (const auto &copyinClause =
287             std::get_if<Fortran::parser::OmpClause::Copyin>(&clause.u)) {
288       hasCopyin = true;
289       const Fortran::parser::OmpObjectList &ompObjectList = copyinClause->v;
290       for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) {
291         Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
292         if (sym->has<Fortran::semantics::CommonBlockDetails>())
293           TODO(converter.getCurrentLocation(), "common block in Copyin clause");
294         if (Fortran::semantics::IsAllocatableOrPointer(sym->GetUltimate()))
295           TODO(converter.getCurrentLocation(),
296                "pointer or allocatable variables in Copyin clause");
297         assert(sym->has<Fortran::semantics::HostAssocDetails>() &&
298                "No host-association found");
299         converter.copyHostAssociateVar(*sym);
300       }
301     }
302   }
303   // [OMP 5.0, 2.19.6.1] The copy is done after the team is formed and prior to
304   // the execution of the associated structured block. Emit implicit barrier to
305   // synchronize threads and avoid data races on propagation master's thread
306   // values of threadprivate variables to local instances of that variables of
307   // all other implicit threads.
308   if (hasCopyin)
309     firOpBuilder.create<mlir::omp::BarrierOp>(converter.getCurrentLocation());
310   firOpBuilder.restoreInsertionPoint(insPt);
311 }
312 
genObjectList(const Fortran::parser::OmpObjectList & objectList,Fortran::lower::AbstractConverter & converter,llvm::SmallVectorImpl<Value> & operands)313 static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
314                           Fortran::lower::AbstractConverter &converter,
315                           llvm::SmallVectorImpl<Value> &operands) {
316   auto addOperands = [&](Fortran::lower::SymbolRef sym) {
317     const mlir::Value variable = converter.getSymbolAddress(sym);
318     if (variable) {
319       operands.push_back(variable);
320     } else {
321       if (const auto *details =
322               sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
323         operands.push_back(converter.getSymbolAddress(details->symbol()));
324         converter.copySymbolBinding(details->symbol(), sym);
325       }
326     }
327   };
328   for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
329     Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
330     addOperands(*sym);
331   }
332 }
333 
getLoopVarType(Fortran::lower::AbstractConverter & converter,std::size_t loopVarTypeSize)334 static mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
335                                  std::size_t loopVarTypeSize) {
336   // OpenMP runtime requires 32-bit or 64-bit loop variables.
337   loopVarTypeSize = loopVarTypeSize * 8;
338   if (loopVarTypeSize < 32) {
339     loopVarTypeSize = 32;
340   } else if (loopVarTypeSize > 64) {
341     loopVarTypeSize = 64;
342     mlir::emitWarning(converter.getCurrentLocation(),
343                       "OpenMP loop iteration variable cannot have more than 64 "
344                       "bits size and will be narrowed into 64 bits.");
345   }
346   assert((loopVarTypeSize == 32 || loopVarTypeSize == 64) &&
347          "OpenMP loop iteration variable size must be transformed into 32-bit "
348          "or 64-bit");
349   return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize);
350 }
351 
352 /// Create empty blocks for the current region.
353 /// These blocks replace blocks parented to an enclosing region.
createEmptyRegionBlocks(fir::FirOpBuilder & firOpBuilder,std::list<Fortran::lower::pft::Evaluation> & evaluationList)354 void createEmptyRegionBlocks(
355     fir::FirOpBuilder &firOpBuilder,
356     std::list<Fortran::lower::pft::Evaluation> &evaluationList) {
357   auto *region = &firOpBuilder.getRegion();
358   for (auto &eval : evaluationList) {
359     if (eval.block) {
360       if (eval.block->empty()) {
361         eval.block->erase();
362         eval.block = firOpBuilder.createBlock(region);
363       } else {
364         [[maybe_unused]] auto &terminatorOp = eval.block->back();
365         assert((mlir::isa<mlir::omp::TerminatorOp>(terminatorOp) ||
366                 mlir::isa<mlir::omp::YieldOp>(terminatorOp)) &&
367                "expected terminator op");
368       }
369     }
370     if (!eval.isDirective() && eval.hasNestedEvaluations())
371       createEmptyRegionBlocks(firOpBuilder, eval.getNestedEvaluations());
372   }
373 }
374 
resetBeforeTerminator(fir::FirOpBuilder & firOpBuilder,mlir::Operation * storeOp,mlir::Block & block)375 void resetBeforeTerminator(fir::FirOpBuilder &firOpBuilder,
376                            mlir::Operation *storeOp, mlir::Block &block) {
377   if (storeOp)
378     firOpBuilder.setInsertionPointAfter(storeOp);
379   else
380     firOpBuilder.setInsertionPointToStart(&block);
381 }
382 
383 /// Create the body (block) for an OpenMP Operation.
384 ///
385 /// \param [in]    op - the operation the body belongs to.
386 /// \param [inout] converter - converter to use for the clauses.
387 /// \param [in]    loc - location in source code.
388 /// \param [in]    eval - current PFT node/evaluation.
389 /// \oaran [in]    clauses - list of clauses to process.
390 /// \param [in]    args - block arguments (induction variable[s]) for the
391 ////                      region.
392 /// \param [in]    outerCombined - is this an outer operation - prevents
393 ///                                privatization.
394 template <typename Op>
395 static void
createBodyOfOp(Op & op,Fortran::lower::AbstractConverter & converter,mlir::Location & loc,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OmpClauseList * clauses=nullptr,const SmallVector<const Fortran::semantics::Symbol * > & args={},bool outerCombined=false)396 createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter,
397                mlir::Location &loc, Fortran::lower::pft::Evaluation &eval,
398                const Fortran::parser::OmpClauseList *clauses = nullptr,
399                const SmallVector<const Fortran::semantics::Symbol *> &args = {},
400                bool outerCombined = false) {
401   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
402   // If an argument for the region is provided then create the block with that
403   // argument. Also update the symbol's address with the mlir argument value.
404   // e.g. For loops the argument is the induction variable. And all further
405   // uses of the induction variable should use this mlir value.
406   mlir::Operation *storeOp = nullptr;
407   if (args.size()) {
408     std::size_t loopVarTypeSize = 0;
409     for (const Fortran::semantics::Symbol *arg : args)
410       loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
411     mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
412     SmallVector<Type> tiv;
413     SmallVector<Location> locs;
414     for (int i = 0; i < (int)args.size(); i++) {
415       tiv.push_back(loopVarType);
416       locs.push_back(loc);
417     }
418     firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs);
419     int argIndex = 0;
420     // The argument is not currently in memory, so make a temporary for the
421     // argument, and store it there, then bind that location to the argument.
422     for (const Fortran::semantics::Symbol *arg : args) {
423       mlir::Value val =
424           fir::getBase(op.getRegion().front().getArgument(argIndex));
425       mlir::Value temp = firOpBuilder.createTemporary(
426           loc, loopVarType,
427           llvm::ArrayRef<mlir::NamedAttribute>{
428               Fortran::lower::getAdaptToByRefAttr(firOpBuilder)});
429       storeOp = firOpBuilder.create<fir::StoreOp>(loc, val, temp);
430       converter.bindSymbol(*arg, temp);
431       argIndex++;
432     }
433   } else {
434     firOpBuilder.createBlock(&op.getRegion());
435   }
436   // Set the insert for the terminator operation to go at the end of the
437   // block - this is either empty or the block with the stores above,
438   // the end of the block works for both.
439   mlir::Block &block = op.getRegion().back();
440   firOpBuilder.setInsertionPointToEnd(&block);
441 
442   // If it is an unstructured region and is not the outer region of a combined
443   // construct, create empty blocks for all evaluations.
444   if (eval.lowerAsUnstructured() && !outerCombined)
445     createEmptyRegionBlocks(firOpBuilder, eval.getNestedEvaluations());
446 
447   // Insert the terminator.
448   if constexpr (std::is_same_v<Op, omp::WsLoopOp> ||
449                 std::is_same_v<Op, omp::SimdLoopOp>) {
450     mlir::ValueRange results;
451     firOpBuilder.create<mlir::omp::YieldOp>(loc, results);
452   } else {
453     firOpBuilder.create<mlir::omp::TerminatorOp>(loc);
454   }
455 
456   // Reset the insert point to before the terminator.
457   resetBeforeTerminator(firOpBuilder, storeOp, block);
458 
459   // Handle privatization. Do not privatize if this is the outer operation.
460   if (clauses && !outerCombined) {
461     bool lastPrivateOp = privatizeVars(op, converter, *clauses);
462     // LastPrivatization, due to introduction of
463     // new control flow, changes the insertion point,
464     // thus restore it.
465     // TODO: Clean up later a bit to avoid this many sets and resets.
466     if (lastPrivateOp)
467       resetBeforeTerminator(firOpBuilder, storeOp, block);
468   }
469 
470   if constexpr (std::is_same_v<Op, omp::ParallelOp>) {
471     threadPrivatizeVars(converter, eval);
472     if (clauses)
473       genCopyinClause(converter, *clauses);
474   }
475 }
476 
genOMP(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OpenMPSimpleStandaloneConstruct & simpleStandaloneConstruct)477 static void genOMP(Fortran::lower::AbstractConverter &converter,
478                    Fortran::lower::pft::Evaluation &eval,
479                    const Fortran::parser::OpenMPSimpleStandaloneConstruct
480                        &simpleStandaloneConstruct) {
481   const auto &directive =
482       std::get<Fortran::parser::OmpSimpleStandaloneDirective>(
483           simpleStandaloneConstruct.t);
484   switch (directive.v) {
485   default:
486     break;
487   case llvm::omp::Directive::OMPD_barrier:
488     converter.getFirOpBuilder().create<mlir::omp::BarrierOp>(
489         converter.getCurrentLocation());
490     break;
491   case llvm::omp::Directive::OMPD_taskwait:
492     converter.getFirOpBuilder().create<mlir::omp::TaskwaitOp>(
493         converter.getCurrentLocation());
494     break;
495   case llvm::omp::Directive::OMPD_taskyield:
496     converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>(
497         converter.getCurrentLocation());
498     break;
499   case llvm::omp::Directive::OMPD_target_enter_data:
500     TODO(converter.getCurrentLocation(), "OMPD_target_enter_data");
501   case llvm::omp::Directive::OMPD_target_exit_data:
502     TODO(converter.getCurrentLocation(), "OMPD_target_exit_data");
503   case llvm::omp::Directive::OMPD_target_update:
504     TODO(converter.getCurrentLocation(), "OMPD_target_update");
505   case llvm::omp::Directive::OMPD_ordered:
506     TODO(converter.getCurrentLocation(), "OMPD_ordered");
507   }
508 }
509 
510 static void
genAllocateClause(Fortran::lower::AbstractConverter & converter,const Fortran::parser::OmpAllocateClause & ompAllocateClause,SmallVector<Value> & allocatorOperands,SmallVector<Value> & allocateOperands)511 genAllocateClause(Fortran::lower::AbstractConverter &converter,
512                   const Fortran::parser::OmpAllocateClause &ompAllocateClause,
513                   SmallVector<Value> &allocatorOperands,
514                   SmallVector<Value> &allocateOperands) {
515   auto &firOpBuilder = converter.getFirOpBuilder();
516   auto currentLocation = converter.getCurrentLocation();
517   Fortran::lower::StatementContext stmtCtx;
518 
519   mlir::Value allocatorOperand;
520   const Fortran::parser::OmpObjectList &ompObjectList =
521       std::get<Fortran::parser::OmpObjectList>(ompAllocateClause.t);
522   const auto &allocatorValue =
523       std::get<std::optional<Fortran::parser::OmpAllocateClause::Allocator>>(
524           ompAllocateClause.t);
525   // Check if allocate clause has allocator specified. If so, add it
526   // to list of allocators, otherwise, add default allocator to
527   // list of allocators.
528   if (allocatorValue) {
529     allocatorOperand = fir::getBase(converter.genExprValue(
530         *Fortran::semantics::GetExpr(allocatorValue->v), stmtCtx));
531     allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(),
532                              allocatorOperand);
533   } else {
534     allocatorOperand = firOpBuilder.createIntegerConstant(
535         currentLocation, firOpBuilder.getI32Type(), 1);
536     allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(),
537                              allocatorOperand);
538   }
539   genObjectList(ompObjectList, converter, allocateOperands);
540 }
541 
542 static void
genOMP(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OpenMPStandaloneConstruct & standaloneConstruct)543 genOMP(Fortran::lower::AbstractConverter &converter,
544        Fortran::lower::pft::Evaluation &eval,
545        const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) {
546   std::visit(
547       Fortran::common::visitors{
548           [&](const Fortran::parser::OpenMPSimpleStandaloneConstruct
549                   &simpleStandaloneConstruct) {
550             genOMP(converter, eval, simpleStandaloneConstruct);
551           },
552           [&](const Fortran::parser::OpenMPFlushConstruct &flushConstruct) {
553             SmallVector<Value, 4> operandRange;
554             if (const auto &ompObjectList =
555                     std::get<std::optional<Fortran::parser::OmpObjectList>>(
556                         flushConstruct.t))
557               genObjectList(*ompObjectList, converter, operandRange);
558             const auto &memOrderClause = std::get<std::optional<
559                 std::list<Fortran::parser::OmpMemoryOrderClause>>>(
560                 flushConstruct.t);
561             if (memOrderClause.has_value() && memOrderClause->size() > 0)
562               TODO(converter.getCurrentLocation(),
563                    "Handle OmpMemoryOrderClause");
564             converter.getFirOpBuilder().create<mlir::omp::FlushOp>(
565                 converter.getCurrentLocation(), operandRange);
566           },
567           [&](const Fortran::parser::OpenMPCancelConstruct &cancelConstruct) {
568             TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct");
569           },
570           [&](const Fortran::parser::OpenMPCancellationPointConstruct
571                   &cancellationPointConstruct) {
572             TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct");
573           },
574       },
575       standaloneConstruct.u);
576 }
577 
genProcBindKindAttr(fir::FirOpBuilder & firOpBuilder,const Fortran::parser::OmpClause::ProcBind * procBindClause)578 static omp::ClauseProcBindKindAttr genProcBindKindAttr(
579     fir::FirOpBuilder &firOpBuilder,
580     const Fortran::parser::OmpClause::ProcBind *procBindClause) {
581   omp::ClauseProcBindKind pbKind;
582   switch (procBindClause->v.v) {
583   case Fortran::parser::OmpProcBindClause::Type::Master:
584     pbKind = omp::ClauseProcBindKind::Master;
585     break;
586   case Fortran::parser::OmpProcBindClause::Type::Close:
587     pbKind = omp::ClauseProcBindKind::Close;
588     break;
589   case Fortran::parser::OmpProcBindClause::Type::Spread:
590     pbKind = omp::ClauseProcBindKind::Spread;
591     break;
592   case Fortran::parser::OmpProcBindClause::Type::Primary:
593     pbKind = omp::ClauseProcBindKind::Primary;
594     break;
595   }
596   return omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(), pbKind);
597 }
598 
599 static mlir::Value
getIfClauseOperand(Fortran::lower::AbstractConverter & converter,Fortran::lower::StatementContext & stmtCtx,const Fortran::parser::OmpClause::If * ifClause)600 getIfClauseOperand(Fortran::lower::AbstractConverter &converter,
601                    Fortran::lower::StatementContext &stmtCtx,
602                    const Fortran::parser::OmpClause::If *ifClause) {
603   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
604   mlir::Location currentLocation = converter.getCurrentLocation();
605   auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
606   mlir::Value ifVal = fir::getBase(
607       converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
608   return firOpBuilder.createConvert(currentLocation, firOpBuilder.getI1Type(),
609                                     ifVal);
610 }
611 
612 /* When parallel is used in a combined construct, then use this function to
613  * create the parallel operation. It handles the parallel specific clauses
614  * and leaves the rest for handling at the inner operations.
615  * TODO: Refactor clause handling
616  */
617 template <typename Directive>
618 static void
createCombinedParallelOp(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Directive & directive)619 createCombinedParallelOp(Fortran::lower::AbstractConverter &converter,
620                          Fortran::lower::pft::Evaluation &eval,
621                          const Directive &directive) {
622   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
623   mlir::Location currentLocation = converter.getCurrentLocation();
624   Fortran::lower::StatementContext stmtCtx;
625   llvm::ArrayRef<mlir::Type> argTy;
626   mlir::Value ifClauseOperand, numThreadsClauseOperand;
627   SmallVector<Value> allocatorOperands, allocateOperands;
628   mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
629   const auto &opClauseList =
630       std::get<Fortran::parser::OmpClauseList>(directive.t);
631   // TODO: Handle the following clauses
632   // 1. default
633   // Note: rest of the clauses are handled when the inner operation is created
634   for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
635     if (const auto &ifClause =
636             std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
637       ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause);
638     } else if (const auto &numThreadsClause =
639                    std::get_if<Fortran::parser::OmpClause::NumThreads>(
640                        &clause.u)) {
641       numThreadsClauseOperand = fir::getBase(converter.genExprValue(
642           *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx));
643     } else if (const auto &procBindClause =
644                    std::get_if<Fortran::parser::OmpClause::ProcBind>(
645                        &clause.u)) {
646       procBindKindAttr = genProcBindKindAttr(firOpBuilder, procBindClause);
647     }
648   }
649   // Create and insert the operation.
650   auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
651       currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
652       allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(),
653       /*reductions=*/nullptr, procBindKindAttr);
654 
655   createBodyOfOp<omp::ParallelOp>(parallelOp, converter, currentLocation, eval,
656                                   &opClauseList, /*iv=*/{},
657                                   /*isCombined=*/true);
658 }
659 
660 static void
genOMP(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OpenMPBlockConstruct & blockConstruct)661 genOMP(Fortran::lower::AbstractConverter &converter,
662        Fortran::lower::pft::Evaluation &eval,
663        const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
664   const auto &beginBlockDirective =
665       std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t);
666   const auto &blockDirective =
667       std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t);
668   const auto &endBlockDirective =
669       std::get<Fortran::parser::OmpEndBlockDirective>(blockConstruct.t);
670   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
671   mlir::Location currentLocation = converter.getCurrentLocation();
672 
673   Fortran::lower::StatementContext stmtCtx;
674   llvm::ArrayRef<mlir::Type> argTy;
675   mlir::Value ifClauseOperand, numThreadsClauseOperand, finalClauseOperand,
676       priorityClauseOperand;
677   mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
678   SmallVector<Value> allocateOperands, allocatorOperands;
679   mlir::UnitAttr nowaitAttr, untiedAttr, mergeableAttr;
680 
681   const auto &opClauseList =
682       std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
683   for (const auto &clause : opClauseList.v) {
684     if (const auto &ifClause =
685             std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
686       ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause);
687     } else if (const auto &numThreadsClause =
688                    std::get_if<Fortran::parser::OmpClause::NumThreads>(
689                        &clause.u)) {
690       // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`.
691       numThreadsClauseOperand = fir::getBase(converter.genExprValue(
692           *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx));
693     } else if (const auto &procBindClause =
694                    std::get_if<Fortran::parser::OmpClause::ProcBind>(
695                        &clause.u)) {
696       procBindKindAttr = genProcBindKindAttr(firOpBuilder, procBindClause);
697     } else if (const auto &allocateClause =
698                    std::get_if<Fortran::parser::OmpClause::Allocate>(
699                        &clause.u)) {
700       genAllocateClause(converter, allocateClause->v, allocatorOperands,
701                         allocateOperands);
702     } else if (std::get_if<Fortran::parser::OmpClause::Private>(&clause.u) ||
703                std::get_if<Fortran::parser::OmpClause::Firstprivate>(
704                    &clause.u) ||
705                std::get_if<Fortran::parser::OmpClause::Copyin>(&clause.u)) {
706       // Privatisation and copyin clauses are handled elsewhere.
707       continue;
708     } else if (std::get_if<Fortran::parser::OmpClause::Shared>(&clause.u)) {
709       // Shared is the default behavior in the IR, so no handling is required.
710       continue;
711     } else if (const auto &defaultClause =
712                    std::get_if<Fortran::parser::OmpClause::Default>(
713                        &clause.u)) {
714       if ((defaultClause->v.v ==
715            Fortran::parser::OmpDefaultClause::Type::Shared) ||
716           (defaultClause->v.v ==
717            Fortran::parser::OmpDefaultClause::Type::None)) {
718         // Default clause with shared or none do not require any handling since
719         // Shared is the default behavior in the IR and None is only required
720         // for semantic checks.
721         continue;
722       }
723     } else if (std::get_if<Fortran::parser::OmpClause::Threads>(&clause.u)) {
724       // Nothing needs to be done for threads clause.
725       continue;
726     } else if (const auto &finalClause =
727                    std::get_if<Fortran::parser::OmpClause::Final>(&clause.u)) {
728       mlir::Value finalVal = fir::getBase(converter.genExprValue(
729           *Fortran::semantics::GetExpr(finalClause->v), stmtCtx));
730       finalClauseOperand = firOpBuilder.createConvert(
731           currentLocation, firOpBuilder.getI1Type(), finalVal);
732     } else if (std::get_if<Fortran::parser::OmpClause::Untied>(&clause.u)) {
733       untiedAttr = firOpBuilder.getUnitAttr();
734     } else if (std::get_if<Fortran::parser::OmpClause::Mergeable>(&clause.u)) {
735       mergeableAttr = firOpBuilder.getUnitAttr();
736     } else if (const auto &priorityClause =
737                    std::get_if<Fortran::parser::OmpClause::Priority>(
738                        &clause.u)) {
739       priorityClauseOperand = fir::getBase(converter.genExprValue(
740           *Fortran::semantics::GetExpr(priorityClause->v), stmtCtx));
741     } else {
742       TODO(currentLocation, "OpenMP Block construct clauses");
743     }
744   }
745 
746   for (const auto &clause :
747        std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t).v) {
748     if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u))
749       nowaitAttr = firOpBuilder.getUnitAttr();
750   }
751 
752   if (blockDirective.v == llvm::omp::OMPD_parallel) {
753     // Create and insert the operation.
754     auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
755         currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
756         allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(),
757         /*reductions=*/nullptr, procBindKindAttr);
758     createBodyOfOp<omp::ParallelOp>(parallelOp, converter, currentLocation,
759                                     eval, &opClauseList);
760   } else if (blockDirective.v == llvm::omp::OMPD_master) {
761     auto masterOp =
762         firOpBuilder.create<mlir::omp::MasterOp>(currentLocation, argTy);
763     createBodyOfOp<omp::MasterOp>(masterOp, converter, currentLocation, eval);
764   } else if (blockDirective.v == llvm::omp::OMPD_single) {
765     auto singleOp = firOpBuilder.create<mlir::omp::SingleOp>(
766         currentLocation, allocateOperands, allocatorOperands, nowaitAttr);
767     createBodyOfOp<omp::SingleOp>(singleOp, converter, currentLocation, eval);
768   } else if (blockDirective.v == llvm::omp::OMPD_ordered) {
769     auto orderedOp = firOpBuilder.create<mlir::omp::OrderedRegionOp>(
770         currentLocation, /*simd=*/nullptr);
771     createBodyOfOp<omp::OrderedRegionOp>(orderedOp, converter, currentLocation,
772                                          eval);
773   } else if (blockDirective.v == llvm::omp::OMPD_task) {
774     auto taskOp = firOpBuilder.create<mlir::omp::TaskOp>(
775         currentLocation, ifClauseOperand, finalClauseOperand, untiedAttr,
776         mergeableAttr, /*in_reduction_vars=*/ValueRange(),
777         /*in_reductions=*/nullptr, priorityClauseOperand, allocateOperands,
778         allocatorOperands);
779     createBodyOfOp(taskOp, converter, currentLocation, eval, &opClauseList);
780   } else {
781     TODO(converter.getCurrentLocation(), "Unhandled block directive");
782   }
783 }
784 
785 /// Creates an OpenMP reduction declaration and inserts it into the provided
786 /// symbol table. The declaration has a constant initializer with the neutral
787 /// value `initValue`, and the reduction combiner carried over from `reduce`.
788 /// TODO: Generalize this for non-integer types, add atomic region.
createReductionDecl(fir::FirOpBuilder & builder,llvm::StringRef name,mlir::Type type,mlir::Location loc)789 static omp::ReductionDeclareOp createReductionDecl(fir::FirOpBuilder &builder,
790                                                    llvm::StringRef name,
791                                                    mlir::Type type,
792                                                    mlir::Location loc) {
793   OpBuilder::InsertionGuard guard(builder);
794   mlir::ModuleOp module = builder.getModule();
795   mlir::OpBuilder modBuilder(module.getBodyRegion());
796   auto decl = module.lookupSymbol<mlir::omp::ReductionDeclareOp>(name);
797   if (!decl)
798     decl = modBuilder.create<omp::ReductionDeclareOp>(loc, name, type);
799   else
800     return decl;
801 
802   builder.createBlock(&decl.initializerRegion(), decl.initializerRegion().end(),
803                       {type}, {loc});
804   builder.setInsertionPointToEnd(&decl.initializerRegion().back());
805   Value init = builder.create<mlir::arith::ConstantOp>(
806       loc, type, builder.getIntegerAttr(type, 0));
807   builder.create<omp::YieldOp>(loc, init);
808 
809   builder.createBlock(&decl.reductionRegion(), decl.reductionRegion().end(),
810                       {type, type}, {loc, loc});
811   builder.setInsertionPointToEnd(&decl.reductionRegion().back());
812   mlir::Value op1 = decl.reductionRegion().front().getArgument(0);
813   mlir::Value op2 = decl.reductionRegion().front().getArgument(1);
814   Value addRes = builder.create<mlir::arith::AddIOp>(loc, op1, op2);
815   builder.create<omp::YieldOp>(loc, addRes);
816   return decl;
817 }
818 
819 static mlir::omp::ScheduleModifier
translateModifier(const Fortran::parser::OmpScheduleModifierType & m)820 translateModifier(const Fortran::parser::OmpScheduleModifierType &m) {
821   switch (m.v) {
822   case Fortran::parser::OmpScheduleModifierType::ModType::Monotonic:
823     return mlir::omp::ScheduleModifier::monotonic;
824   case Fortran::parser::OmpScheduleModifierType::ModType::Nonmonotonic:
825     return mlir::omp::ScheduleModifier::nonmonotonic;
826   case Fortran::parser::OmpScheduleModifierType::ModType::Simd:
827     return mlir::omp::ScheduleModifier::simd;
828   }
829   return mlir::omp::ScheduleModifier::none;
830 }
831 
832 static mlir::omp::ScheduleModifier
getScheduleModifier(const Fortran::parser::OmpScheduleClause & x)833 getScheduleModifier(const Fortran::parser::OmpScheduleClause &x) {
834   const auto &modifier =
835       std::get<std::optional<Fortran::parser::OmpScheduleModifier>>(x.t);
836   // The input may have the modifier any order, so we look for one that isn't
837   // SIMD. If modifier is not set at all, fall down to the bottom and return
838   // "none".
839   if (modifier) {
840     const auto &modType1 =
841         std::get<Fortran::parser::OmpScheduleModifier::Modifier1>(modifier->t);
842     if (modType1.v.v ==
843         Fortran::parser::OmpScheduleModifierType::ModType::Simd) {
844       const auto &modType2 = std::get<
845           std::optional<Fortran::parser::OmpScheduleModifier::Modifier2>>(
846           modifier->t);
847       if (modType2 &&
848           modType2->v.v !=
849               Fortran::parser::OmpScheduleModifierType::ModType::Simd)
850         return translateModifier(modType2->v);
851 
852       return mlir::omp::ScheduleModifier::none;
853     }
854 
855     return translateModifier(modType1.v);
856   }
857   return mlir::omp::ScheduleModifier::none;
858 }
859 
860 static mlir::omp::ScheduleModifier
getSIMDModifier(const Fortran::parser::OmpScheduleClause & x)861 getSIMDModifier(const Fortran::parser::OmpScheduleClause &x) {
862   const auto &modifier =
863       std::get<std::optional<Fortran::parser::OmpScheduleModifier>>(x.t);
864   // Either of the two possible modifiers in the input can be the SIMD modifier,
865   // so look in either one, and return simd if we find one. Not found = return
866   // "none".
867   if (modifier) {
868     const auto &modType1 =
869         std::get<Fortran::parser::OmpScheduleModifier::Modifier1>(modifier->t);
870     if (modType1.v.v == Fortran::parser::OmpScheduleModifierType::ModType::Simd)
871       return mlir::omp::ScheduleModifier::simd;
872 
873     const auto &modType2 = std::get<
874         std::optional<Fortran::parser::OmpScheduleModifier::Modifier2>>(
875         modifier->t);
876     if (modType2 && modType2->v.v ==
877                         Fortran::parser::OmpScheduleModifierType::ModType::Simd)
878       return mlir::omp::ScheduleModifier::simd;
879   }
880   return mlir::omp::ScheduleModifier::none;
881 }
882 
getReductionName(Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,mlir::Type ty)883 static std::string getReductionName(
884     Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
885     mlir::Type ty) {
886   std::string reductionName;
887   if (intrinsicOp == Fortran::parser::DefinedOperator::IntrinsicOperator::Add)
888     reductionName = "add_reduction";
889   else
890     reductionName = "other_reduction";
891 
892   return (llvm::Twine(reductionName) +
893           (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
894           llvm::Twine(ty.getIntOrFloatBitWidth()))
895       .str();
896 }
897 
genOMP(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OpenMPLoopConstruct & loopConstruct)898 static void genOMP(Fortran::lower::AbstractConverter &converter,
899                    Fortran::lower::pft::Evaluation &eval,
900                    const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
901 
902   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
903   mlir::Location currentLocation = converter.getCurrentLocation();
904   llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, linearVars,
905       linearStepVars, reductionVars;
906   mlir::Value scheduleChunkClauseOperand, ifClauseOperand;
907   mlir::Attribute scheduleClauseOperand, noWaitClauseOperand,
908       orderedClauseOperand, orderClauseOperand;
909   SmallVector<Attribute> reductionDeclSymbols;
910   Fortran::lower::StatementContext stmtCtx;
911   const auto &loopOpClauseList = std::get<Fortran::parser::OmpClauseList>(
912       std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t).t);
913 
914   const auto ompDirective =
915       std::get<Fortran::parser::OmpLoopDirective>(
916           std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t).t)
917           .v;
918   if (llvm::omp::OMPD_parallel_do == ompDirective) {
919     createCombinedParallelOp<Fortran::parser::OmpBeginLoopDirective>(
920         converter, eval,
921         std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t));
922   } else if (llvm::omp::OMPD_do != ompDirective &&
923              llvm::omp::OMPD_simd != ompDirective) {
924     TODO(converter.getCurrentLocation(), "Construct enclosing do loop");
925   }
926 
927   // Collect the loops to collapse.
928   auto *doConstructEval = &eval.getFirstNestedEvaluation();
929 
930   std::int64_t collapseValue =
931       Fortran::lower::getCollapseValue(loopOpClauseList);
932   std::size_t loopVarTypeSize = 0;
933   SmallVector<const Fortran::semantics::Symbol *> iv;
934   do {
935     auto *doLoop = &doConstructEval->getFirstNestedEvaluation();
936     auto *doStmt = doLoop->getIf<Fortran::parser::NonLabelDoStmt>();
937     assert(doStmt && "Expected do loop to be in the nested evaluation");
938     const auto &loopControl =
939         std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t);
940     const Fortran::parser::LoopControl::Bounds *bounds =
941         std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
942     assert(bounds && "Expected bounds for worksharing do loop");
943     Fortran::lower::StatementContext stmtCtx;
944     lowerBound.push_back(fir::getBase(converter.genExprValue(
945         *Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
946     upperBound.push_back(fir::getBase(converter.genExprValue(
947         *Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
948     if (bounds->step) {
949       step.push_back(fir::getBase(converter.genExprValue(
950           *Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
951     } else { // If `step` is not present, assume it as `1`.
952       step.push_back(firOpBuilder.createIntegerConstant(
953           currentLocation, firOpBuilder.getIntegerType(32), 1));
954     }
955     iv.push_back(bounds->name.thing.symbol);
956     loopVarTypeSize = std::max(loopVarTypeSize,
957                                bounds->name.thing.symbol->GetUltimate().size());
958 
959     collapseValue--;
960     doConstructEval =
961         &*std::next(doConstructEval->getNestedEvaluations().begin());
962   } while (collapseValue > 0);
963 
964   for (const auto &clause : loopOpClauseList.v) {
965     if (const auto &scheduleClause =
966             std::get_if<Fortran::parser::OmpClause::Schedule>(&clause.u)) {
967       if (const auto &chunkExpr =
968               std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
969                   scheduleClause->v.t)) {
970         if (const auto *expr = Fortran::semantics::GetExpr(*chunkExpr)) {
971           scheduleChunkClauseOperand =
972               fir::getBase(converter.genExprValue(*expr, stmtCtx));
973         }
974       }
975     } else if (const auto &ifClause =
976                    std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
977       ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause);
978     } else if (const auto &reductionClause =
979                    std::get_if<Fortran::parser::OmpClause::Reduction>(
980                        &clause.u)) {
981       omp::ReductionDeclareOp decl;
982       const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>(
983           reductionClause->v.t)};
984       const auto &objectList{
985           std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)};
986       if (const auto &redDefinedOp =
987               std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
988         const auto &intrinsicOp{
989             std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
990                 redDefinedOp->u)};
991         if (intrinsicOp !=
992             Fortran::parser::DefinedOperator::IntrinsicOperator::Add)
993           TODO(currentLocation,
994                "Reduction of some intrinsic operators is not supported");
995         for (const auto &ompObject : objectList.v) {
996           if (const auto *name{
997                   Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
998             if (const auto *symbol{name->symbol}) {
999               mlir::Value symVal = converter.getSymbolAddress(*symbol);
1000               mlir::Type redType =
1001                   symVal.getType().cast<fir::ReferenceType>().getEleTy();
1002               reductionVars.push_back(symVal);
1003               if (redType.isIntOrIndex()) {
1004                 decl = createReductionDecl(
1005                     firOpBuilder, getReductionName(intrinsicOp, redType),
1006                     redType, currentLocation);
1007               } else {
1008                 TODO(currentLocation,
1009                      "Reduction of some types is not supported");
1010               }
1011               reductionDeclSymbols.push_back(SymbolRefAttr::get(
1012                   firOpBuilder.getContext(), decl.sym_name()));
1013             }
1014           }
1015         }
1016       } else {
1017         TODO(currentLocation,
1018              "Reduction of intrinsic procedures is not supported");
1019       }
1020     }
1021   }
1022 
1023   // The types of lower bound, upper bound, and step are converted into the
1024   // type of the loop variable if necessary.
1025   mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
1026   for (unsigned it = 0; it < (unsigned)lowerBound.size(); it++) {
1027     lowerBound[it] = firOpBuilder.createConvert(currentLocation, loopVarType,
1028                                                 lowerBound[it]);
1029     upperBound[it] = firOpBuilder.createConvert(currentLocation, loopVarType,
1030                                                 upperBound[it]);
1031     step[it] =
1032         firOpBuilder.createConvert(currentLocation, loopVarType, step[it]);
1033   }
1034 
1035   // 2.9.3.1 SIMD construct
1036   // TODO: Support all the clauses
1037   if (llvm::omp::OMPD_simd == ompDirective) {
1038     TypeRange resultType;
1039     auto SimdLoopOp = firOpBuilder.create<mlir::omp::SimdLoopOp>(
1040         currentLocation, resultType, lowerBound, upperBound, step,
1041         ifClauseOperand, /*inclusive=*/firOpBuilder.getUnitAttr());
1042     createBodyOfOp<omp::SimdLoopOp>(SimdLoopOp, converter, currentLocation,
1043                                     eval, &loopOpClauseList, iv);
1044     return;
1045   }
1046 
1047   // FIXME: Add support for following clauses:
1048   // 1. linear
1049   // 2. order
1050   auto wsLoopOp = firOpBuilder.create<mlir::omp::WsLoopOp>(
1051       currentLocation, lowerBound, upperBound, step, linearVars, linearStepVars,
1052       reductionVars,
1053       reductionDeclSymbols.empty()
1054           ? nullptr
1055           : mlir::ArrayAttr::get(firOpBuilder.getContext(),
1056                                  reductionDeclSymbols),
1057       scheduleClauseOperand.dyn_cast_or_null<omp::ClauseScheduleKindAttr>(),
1058       scheduleChunkClauseOperand, /*schedule_modifiers=*/nullptr,
1059       /*simd_modifier=*/nullptr,
1060       noWaitClauseOperand.dyn_cast_or_null<UnitAttr>(),
1061       orderedClauseOperand.dyn_cast_or_null<IntegerAttr>(),
1062       orderClauseOperand.dyn_cast_or_null<omp::ClauseOrderKindAttr>(),
1063       /*inclusive=*/firOpBuilder.getUnitAttr());
1064 
1065   // Handle attribute based clauses.
1066   for (const Fortran::parser::OmpClause &clause : loopOpClauseList.v) {
1067     if (const auto &orderedClause =
1068             std::get_if<Fortran::parser::OmpClause::Ordered>(&clause.u)) {
1069       if (orderedClause->v.has_value()) {
1070         const auto *expr = Fortran::semantics::GetExpr(orderedClause->v);
1071         const std::optional<std::int64_t> orderedClauseValue =
1072             Fortran::evaluate::ToInt64(*expr);
1073         wsLoopOp.ordered_valAttr(
1074             firOpBuilder.getI64IntegerAttr(*orderedClauseValue));
1075       } else {
1076         wsLoopOp.ordered_valAttr(firOpBuilder.getI64IntegerAttr(0));
1077       }
1078     } else if (const auto &scheduleClause =
1079                    std::get_if<Fortran::parser::OmpClause::Schedule>(
1080                        &clause.u)) {
1081       mlir::MLIRContext *context = firOpBuilder.getContext();
1082       const auto &scheduleType = scheduleClause->v;
1083       const auto &scheduleKind =
1084           std::get<Fortran::parser::OmpScheduleClause::ScheduleType>(
1085               scheduleType.t);
1086       switch (scheduleKind) {
1087       case Fortran::parser::OmpScheduleClause::ScheduleType::Static:
1088         wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get(
1089             context, omp::ClauseScheduleKind::Static));
1090         break;
1091       case Fortran::parser::OmpScheduleClause::ScheduleType::Dynamic:
1092         wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get(
1093             context, omp::ClauseScheduleKind::Dynamic));
1094         break;
1095       case Fortran::parser::OmpScheduleClause::ScheduleType::Guided:
1096         wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get(
1097             context, omp::ClauseScheduleKind::Guided));
1098         break;
1099       case Fortran::parser::OmpScheduleClause::ScheduleType::Auto:
1100         wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get(
1101             context, omp::ClauseScheduleKind::Auto));
1102         break;
1103       case Fortran::parser::OmpScheduleClause::ScheduleType::Runtime:
1104         wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get(
1105             context, omp::ClauseScheduleKind::Runtime));
1106         break;
1107       }
1108       mlir::omp::ScheduleModifier scheduleModifier =
1109           getScheduleModifier(scheduleClause->v);
1110       if (scheduleModifier != mlir::omp::ScheduleModifier::none)
1111         wsLoopOp.schedule_modifierAttr(
1112             omp::ScheduleModifierAttr::get(context, scheduleModifier));
1113       if (getSIMDModifier(scheduleClause->v) !=
1114           mlir::omp::ScheduleModifier::none)
1115         wsLoopOp.simd_modifierAttr(firOpBuilder.getUnitAttr());
1116     }
1117   }
1118   // In FORTRAN `nowait` clause occur at the end of `omp do` directive.
1119   // i.e
1120   // !$omp do
1121   // <...>
1122   // !$omp end do nowait
1123   if (const auto &endClauseList =
1124           std::get<std::optional<Fortran::parser::OmpEndLoopDirective>>(
1125               loopConstruct.t)) {
1126     const auto &clauseList =
1127         std::get<Fortran::parser::OmpClauseList>((*endClauseList).t);
1128     for (const Fortran::parser::OmpClause &clause : clauseList.v)
1129       if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u))
1130         wsLoopOp.nowaitAttr(firOpBuilder.getUnitAttr());
1131   }
1132 
1133   createBodyOfOp<omp::WsLoopOp>(wsLoopOp, converter, currentLocation, eval,
1134                                 &loopOpClauseList, iv);
1135 }
1136 
1137 static void
genOMP(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OpenMPCriticalConstruct & criticalConstruct)1138 genOMP(Fortran::lower::AbstractConverter &converter,
1139        Fortran::lower::pft::Evaluation &eval,
1140        const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) {
1141   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1142   mlir::Location currentLocation = converter.getCurrentLocation();
1143   std::string name;
1144   const Fortran::parser::OmpCriticalDirective &cd =
1145       std::get<Fortran::parser::OmpCriticalDirective>(criticalConstruct.t);
1146   if (std::get<std::optional<Fortran::parser::Name>>(cd.t).has_value()) {
1147     name =
1148         std::get<std::optional<Fortran::parser::Name>>(cd.t).value().ToString();
1149   }
1150 
1151   uint64_t hint = 0;
1152   const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t);
1153   for (const Fortran::parser::OmpClause &clause : clauseList.v)
1154     if (auto hintClause =
1155             std::get_if<Fortran::parser::OmpClause::Hint>(&clause.u)) {
1156       const auto *expr = Fortran::semantics::GetExpr(hintClause->v);
1157       hint = *Fortran::evaluate::ToInt64(*expr);
1158       break;
1159     }
1160 
1161   mlir::omp::CriticalOp criticalOp = [&]() {
1162     if (name.empty()) {
1163       return firOpBuilder.create<mlir::omp::CriticalOp>(currentLocation,
1164                                                         FlatSymbolRefAttr());
1165     } else {
1166       mlir::ModuleOp module = firOpBuilder.getModule();
1167       mlir::OpBuilder modBuilder(module.getBodyRegion());
1168       auto global = module.lookupSymbol<mlir::omp::CriticalDeclareOp>(name);
1169       if (!global)
1170         global = modBuilder.create<mlir::omp::CriticalDeclareOp>(
1171             currentLocation, name, hint);
1172       return firOpBuilder.create<mlir::omp::CriticalOp>(
1173           currentLocation, mlir::FlatSymbolRefAttr::get(
1174                                firOpBuilder.getContext(), global.sym_name()));
1175     }
1176   }();
1177   createBodyOfOp<omp::CriticalOp>(criticalOp, converter, currentLocation, eval);
1178 }
1179 
1180 static void
genOMP(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OpenMPSectionConstruct & sectionConstruct)1181 genOMP(Fortran::lower::AbstractConverter &converter,
1182        Fortran::lower::pft::Evaluation &eval,
1183        const Fortran::parser::OpenMPSectionConstruct &sectionConstruct) {
1184 
1185   auto &firOpBuilder = converter.getFirOpBuilder();
1186   auto currentLocation = converter.getCurrentLocation();
1187   mlir::omp::SectionOp sectionOp =
1188       firOpBuilder.create<mlir::omp::SectionOp>(currentLocation);
1189   createBodyOfOp<omp::SectionOp>(sectionOp, converter, currentLocation, eval);
1190 }
1191 
1192 // TODO: Add support for reduction
1193 static void
genOMP(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OpenMPSectionsConstruct & sectionsConstruct)1194 genOMP(Fortran::lower::AbstractConverter &converter,
1195        Fortran::lower::pft::Evaluation &eval,
1196        const Fortran::parser::OpenMPSectionsConstruct &sectionsConstruct) {
1197   auto &firOpBuilder = converter.getFirOpBuilder();
1198   auto currentLocation = converter.getCurrentLocation();
1199   SmallVector<Value> reductionVars, allocateOperands, allocatorOperands;
1200   mlir::UnitAttr noWaitClauseOperand;
1201   const auto &sectionsClauseList = std::get<Fortran::parser::OmpClauseList>(
1202       std::get<Fortran::parser::OmpBeginSectionsDirective>(sectionsConstruct.t)
1203           .t);
1204   for (const Fortran::parser::OmpClause &clause : sectionsClauseList.v) {
1205 
1206     // Reduction Clause
1207     if (std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) {
1208       TODO(currentLocation, "OMPC_Reduction");
1209 
1210       // Allocate clause
1211     } else if (const auto &allocateClause =
1212                    std::get_if<Fortran::parser::OmpClause::Allocate>(
1213                        &clause.u)) {
1214       genAllocateClause(converter, allocateClause->v, allocatorOperands,
1215                         allocateOperands);
1216     }
1217   }
1218   const auto &endSectionsClauseList =
1219       std::get<Fortran::parser::OmpEndSectionsDirective>(sectionsConstruct.t);
1220   const auto &clauseList =
1221       std::get<Fortran::parser::OmpClauseList>(endSectionsClauseList.t);
1222   for (const auto &clause : clauseList.v) {
1223     // Nowait clause
1224     if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u)) {
1225       noWaitClauseOperand = firOpBuilder.getUnitAttr();
1226     }
1227   }
1228 
1229   llvm::omp::Directive dir =
1230       std::get<Fortran::parser::OmpSectionsDirective>(
1231           std::get<Fortran::parser::OmpBeginSectionsDirective>(
1232               sectionsConstruct.t)
1233               .t)
1234           .v;
1235 
1236   // Parallel Sections Construct
1237   if (dir == llvm::omp::Directive::OMPD_parallel_sections) {
1238     createCombinedParallelOp<Fortran::parser::OmpBeginSectionsDirective>(
1239         converter, eval,
1240         std::get<Fortran::parser::OmpBeginSectionsDirective>(
1241             sectionsConstruct.t));
1242     auto sectionsOp = firOpBuilder.create<mlir::omp::SectionsOp>(
1243         currentLocation, /*reduction_vars*/ ValueRange(),
1244         /*reductions=*/nullptr, allocateOperands, allocatorOperands,
1245         /*nowait=*/nullptr);
1246     createBodyOfOp(sectionsOp, converter, currentLocation, eval);
1247 
1248     // Sections Construct
1249   } else if (dir == llvm::omp::Directive::OMPD_sections) {
1250     auto sectionsOp = firOpBuilder.create<mlir::omp::SectionsOp>(
1251         currentLocation, reductionVars, /*reductions = */ nullptr,
1252         allocateOperands, allocatorOperands, noWaitClauseOperand);
1253     createBodyOfOp<omp::SectionsOp>(sectionsOp, converter, currentLocation,
1254                                     eval);
1255   }
1256 }
1257 
genOmpAtomicHintAndMemoryOrderClauses(Fortran::lower::AbstractConverter & converter,const Fortran::parser::OmpAtomicClauseList & clauseList,mlir::IntegerAttr & hint,mlir::omp::ClauseMemoryOrderKindAttr & memory_order)1258 static void genOmpAtomicHintAndMemoryOrderClauses(
1259     Fortran::lower::AbstractConverter &converter,
1260     const Fortran::parser::OmpAtomicClauseList &clauseList,
1261     mlir::IntegerAttr &hint,
1262     mlir::omp::ClauseMemoryOrderKindAttr &memory_order) {
1263   auto &firOpBuilder = converter.getFirOpBuilder();
1264   for (const auto &clause : clauseList.v) {
1265     if (auto ompClause = std::get_if<Fortran::parser::OmpClause>(&clause.u)) {
1266       if (auto hintClause =
1267               std::get_if<Fortran::parser::OmpClause::Hint>(&ompClause->u)) {
1268         const auto *expr = Fortran::semantics::GetExpr(hintClause->v);
1269         uint64_t hintExprValue = *Fortran::evaluate::ToInt64(*expr);
1270         hint = firOpBuilder.getI64IntegerAttr(hintExprValue);
1271       }
1272     } else if (auto ompMemoryOrderClause =
1273                    std::get_if<Fortran::parser::OmpMemoryOrderClause>(
1274                        &clause.u)) {
1275       if (std::get_if<Fortran::parser::OmpClause::Acquire>(
1276               &ompMemoryOrderClause->v.u)) {
1277         memory_order = mlir::omp::ClauseMemoryOrderKindAttr::get(
1278             firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Acquire);
1279       } else if (std::get_if<Fortran::parser::OmpClause::Relaxed>(
1280                      &ompMemoryOrderClause->v.u)) {
1281         memory_order = mlir::omp::ClauseMemoryOrderKindAttr::get(
1282             firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Relaxed);
1283       } else if (std::get_if<Fortran::parser::OmpClause::SeqCst>(
1284                      &ompMemoryOrderClause->v.u)) {
1285         memory_order = mlir::omp::ClauseMemoryOrderKindAttr::get(
1286             firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Seq_cst);
1287       } else if (std::get_if<Fortran::parser::OmpClause::Release>(
1288                      &ompMemoryOrderClause->v.u)) {
1289         memory_order = mlir::omp::ClauseMemoryOrderKindAttr::get(
1290             firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Release);
1291       }
1292     }
1293   }
1294 }
1295 
genOmpAtomicUpdateStatement(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::Variable & assignmentStmtVariable,const Fortran::parser::Expr & assignmentStmtExpr,const Fortran::parser::OmpAtomicClauseList * leftHandClauseList,const Fortran::parser::OmpAtomicClauseList * rightHandClauseList)1296 static void genOmpAtomicUpdateStatement(
1297     Fortran::lower::AbstractConverter &converter,
1298     Fortran::lower::pft::Evaluation &eval,
1299     const Fortran::parser::Variable &assignmentStmtVariable,
1300     const Fortran::parser::Expr &assignmentStmtExpr,
1301     const Fortran::parser::OmpAtomicClauseList *leftHandClauseList,
1302     const Fortran::parser::OmpAtomicClauseList *rightHandClauseList) {
1303   // Generate `omp.atomic.update` operation for atomic assignment statements
1304   auto &firOpBuilder = converter.getFirOpBuilder();
1305   auto currentLocation = converter.getCurrentLocation();
1306   Fortran::lower::StatementContext stmtCtx;
1307 
1308   mlir::Value address = fir::getBase(converter.genExprAddr(
1309       *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
1310   // If no hint clause is specified, the effect is as if
1311   // hint(omp_sync_hint_none) had been specified.
1312   mlir::IntegerAttr hint = nullptr;
1313   mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr;
1314   if (leftHandClauseList)
1315     genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList, hint,
1316                                           memory_order);
1317   if (rightHandClauseList)
1318     genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, hint,
1319                                           memory_order);
1320   auto atomicUpdateOp = firOpBuilder.create<mlir::omp::AtomicUpdateOp>(
1321       currentLocation, address, hint, memory_order);
1322 
1323   //// Generate body of Atomic Update operation
1324   // If an argument for the region is provided then create the block with that
1325   // argument. Also update the symbol's address with the argument mlir value.
1326   mlir::Type varType =
1327       fir::getBase(
1328           converter.genExprValue(
1329               *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx))
1330           .getType();
1331   SmallVector<Type> varTys = {varType};
1332   SmallVector<Location> locs = {currentLocation};
1333   firOpBuilder.createBlock(&atomicUpdateOp.getRegion(), {}, varTys, locs);
1334   mlir::Value val =
1335       fir::getBase(atomicUpdateOp.getRegion().front().getArgument(0));
1336   auto varDesignator =
1337       std::get_if<Fortran::common::Indirection<Fortran::parser::Designator>>(
1338           &assignmentStmtVariable.u);
1339   assert(varDesignator && "Variable designator for atomic update assignment "
1340                           "statement does not exist");
1341   const auto *name = getDesignatorNameIfDataRef(varDesignator->value());
1342   assert(name && name->symbol &&
1343          "No symbol attached to atomic update variable");
1344   converter.bindSymbol(*name->symbol, val);
1345   // Set the insert for the terminator operation to go at the end of the
1346   // block.
1347   mlir::Block &block = atomicUpdateOp.getRegion().back();
1348   firOpBuilder.setInsertionPointToEnd(&block);
1349 
1350   mlir::Value result = fir::getBase(converter.genExprValue(
1351       *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
1352   // Insert the terminator: YieldOp.
1353   firOpBuilder.create<mlir::omp::YieldOp>(currentLocation, result);
1354   // Reset the insert point to before the terminator.
1355   firOpBuilder.setInsertionPointToStart(&block);
1356 }
1357 
1358 static void
genOmpAtomicWrite(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OmpAtomicWrite & atomicWrite)1359 genOmpAtomicWrite(Fortran::lower::AbstractConverter &converter,
1360                   Fortran::lower::pft::Evaluation &eval,
1361                   const Fortran::parser::OmpAtomicWrite &atomicWrite) {
1362   auto &firOpBuilder = converter.getFirOpBuilder();
1363   auto currentLocation = converter.getCurrentLocation();
1364   // Get the value and address of atomic write operands.
1365   const Fortran::parser::OmpAtomicClauseList &rightHandClauseList =
1366       std::get<2>(atomicWrite.t);
1367   const Fortran::parser::OmpAtomicClauseList &leftHandClauseList =
1368       std::get<0>(atomicWrite.t);
1369   const auto &assignmentStmtExpr =
1370       std::get<Fortran::parser::Expr>(std::get<3>(atomicWrite.t).statement.t);
1371   const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
1372       std::get<3>(atomicWrite.t).statement.t);
1373   Fortran::lower::StatementContext stmtCtx;
1374   mlir::Value value = fir::getBase(converter.genExprValue(
1375       *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
1376   mlir::Value address = fir::getBase(converter.genExprAddr(
1377       *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
1378   // If no hint clause is specified, the effect is as if
1379   // hint(omp_sync_hint_none) had been specified.
1380   mlir::IntegerAttr hint = nullptr;
1381   mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr;
1382   genOmpAtomicHintAndMemoryOrderClauses(converter, leftHandClauseList, hint,
1383                                         memory_order);
1384   genOmpAtomicHintAndMemoryOrderClauses(converter, rightHandClauseList, hint,
1385                                         memory_order);
1386   firOpBuilder.create<mlir::omp::AtomicWriteOp>(currentLocation, address, value,
1387                                                 hint, memory_order);
1388 }
1389 
genOmpAtomicRead(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OmpAtomicRead & atomicRead)1390 static void genOmpAtomicRead(Fortran::lower::AbstractConverter &converter,
1391                              Fortran::lower::pft::Evaluation &eval,
1392                              const Fortran::parser::OmpAtomicRead &atomicRead) {
1393   auto &firOpBuilder = converter.getFirOpBuilder();
1394   auto currentLocation = converter.getCurrentLocation();
1395   // Get the address of atomic read operands.
1396   const Fortran::parser::OmpAtomicClauseList &rightHandClauseList =
1397       std::get<2>(atomicRead.t);
1398   const Fortran::parser::OmpAtomicClauseList &leftHandClauseList =
1399       std::get<0>(atomicRead.t);
1400   const auto &assignmentStmtExpr =
1401       std::get<Fortran::parser::Expr>(std::get<3>(atomicRead.t).statement.t);
1402   const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
1403       std::get<3>(atomicRead.t).statement.t);
1404   Fortran::lower::StatementContext stmtCtx;
1405   mlir::Value from_address = fir::getBase(converter.genExprAddr(
1406       *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
1407   mlir::Value to_address = fir::getBase(converter.genExprAddr(
1408       *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
1409   // If no hint clause is specified, the effect is as if
1410   // hint(omp_sync_hint_none) had been specified.
1411   mlir::IntegerAttr hint = nullptr;
1412   mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr;
1413   genOmpAtomicHintAndMemoryOrderClauses(converter, leftHandClauseList, hint,
1414                                         memory_order);
1415   genOmpAtomicHintAndMemoryOrderClauses(converter, rightHandClauseList, hint,
1416                                         memory_order);
1417   firOpBuilder.create<mlir::omp::AtomicReadOp>(currentLocation, from_address,
1418                                                to_address, hint, memory_order);
1419 }
1420 
1421 static void
genOmpAtomicUpdate(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OmpAtomicUpdate & atomicUpdate)1422 genOmpAtomicUpdate(Fortran::lower::AbstractConverter &converter,
1423                    Fortran::lower::pft::Evaluation &eval,
1424                    const Fortran::parser::OmpAtomicUpdate &atomicUpdate) {
1425   const Fortran::parser::OmpAtomicClauseList &rightHandClauseList =
1426       std::get<2>(atomicUpdate.t);
1427   const Fortran::parser::OmpAtomicClauseList &leftHandClauseList =
1428       std::get<0>(atomicUpdate.t);
1429   const auto &assignmentStmtExpr =
1430       std::get<Fortran::parser::Expr>(std::get<3>(atomicUpdate.t).statement.t);
1431   const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
1432       std::get<3>(atomicUpdate.t).statement.t);
1433 
1434   genOmpAtomicUpdateStatement(converter, eval, assignmentStmtVariable,
1435                               assignmentStmtExpr, &leftHandClauseList,
1436                               &rightHandClauseList);
1437 }
1438 
genOmpAtomic(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OmpAtomic & atomicConstruct)1439 static void genOmpAtomic(Fortran::lower::AbstractConverter &converter,
1440                          Fortran::lower::pft::Evaluation &eval,
1441                          const Fortran::parser::OmpAtomic &atomicConstruct) {
1442   const Fortran::parser::OmpAtomicClauseList &atomicClauseList =
1443       std::get<Fortran::parser::OmpAtomicClauseList>(atomicConstruct.t);
1444   const auto &assignmentStmtExpr = std::get<Fortran::parser::Expr>(
1445       std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
1446           atomicConstruct.t)
1447           .statement.t);
1448   const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
1449       std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
1450           atomicConstruct.t)
1451           .statement.t);
1452   // If atomic-clause is not present on the construct, the behaviour is as if
1453   // the update clause is specified
1454   genOmpAtomicUpdateStatement(converter, eval, assignmentStmtVariable,
1455                               assignmentStmtExpr, &atomicClauseList, nullptr);
1456 }
1457 
1458 static void
genOMP(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OpenMPAtomicConstruct & atomicConstruct)1459 genOMP(Fortran::lower::AbstractConverter &converter,
1460        Fortran::lower::pft::Evaluation &eval,
1461        const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) {
1462   std::visit(Fortran::common::visitors{
1463                  [&](const Fortran::parser::OmpAtomicRead &atomicRead) {
1464                    genOmpAtomicRead(converter, eval, atomicRead);
1465                  },
1466                  [&](const Fortran::parser::OmpAtomicWrite &atomicWrite) {
1467                    genOmpAtomicWrite(converter, eval, atomicWrite);
1468                  },
1469                  [&](const Fortran::parser::OmpAtomic &atomicConstruct) {
1470                    genOmpAtomic(converter, eval, atomicConstruct);
1471                  },
1472                  [&](const Fortran::parser::OmpAtomicUpdate &atomicUpdate) {
1473                    genOmpAtomicUpdate(converter, eval, atomicUpdate);
1474                  },
1475                  [&](const auto &) {
1476                    TODO(converter.getCurrentLocation(), "Atomic capture");
1477                  },
1478              },
1479              atomicConstruct.u);
1480 }
1481 
genOpenMPConstruct(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OpenMPConstruct & ompConstruct)1482 void Fortran::lower::genOpenMPConstruct(
1483     Fortran::lower::AbstractConverter &converter,
1484     Fortran::lower::pft::Evaluation &eval,
1485     const Fortran::parser::OpenMPConstruct &ompConstruct) {
1486 
1487   std::visit(
1488       common::visitors{
1489           [&](const Fortran::parser::OpenMPStandaloneConstruct
1490                   &standaloneConstruct) {
1491             genOMP(converter, eval, standaloneConstruct);
1492           },
1493           [&](const Fortran::parser::OpenMPSectionsConstruct
1494                   &sectionsConstruct) {
1495             genOMP(converter, eval, sectionsConstruct);
1496           },
1497           [&](const Fortran::parser::OpenMPSectionConstruct &sectionConstruct) {
1498             genOMP(converter, eval, sectionConstruct);
1499           },
1500           [&](const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
1501             genOMP(converter, eval, loopConstruct);
1502           },
1503           [&](const Fortran::parser::OpenMPDeclarativeAllocate
1504                   &execAllocConstruct) {
1505             TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate");
1506           },
1507           [&](const Fortran::parser::OpenMPExecutableAllocate
1508                   &execAllocConstruct) {
1509             TODO(converter.getCurrentLocation(), "OpenMPExecutableAllocate");
1510           },
1511           [&](const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
1512             genOMP(converter, eval, blockConstruct);
1513           },
1514           [&](const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) {
1515             genOMP(converter, eval, atomicConstruct);
1516           },
1517           [&](const Fortran::parser::OpenMPCriticalConstruct
1518                   &criticalConstruct) {
1519             genOMP(converter, eval, criticalConstruct);
1520           },
1521       },
1522       ompConstruct.u);
1523 }
1524 
genThreadprivateOp(Fortran::lower::AbstractConverter & converter,const Fortran::lower::pft::Variable & var)1525 void Fortran::lower::genThreadprivateOp(
1526     Fortran::lower::AbstractConverter &converter,
1527     const Fortran::lower::pft::Variable &var) {
1528   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1529   mlir::Location currentLocation = converter.getCurrentLocation();
1530 
1531   const Fortran::semantics::Symbol &sym = var.getSymbol();
1532   mlir::Value symThreadprivateValue;
1533   if (const Fortran::semantics::Symbol *common =
1534           Fortran::semantics::FindCommonBlockContaining(sym.GetUltimate())) {
1535     mlir::Value commonValue = converter.getSymbolAddress(*common);
1536     if (mlir::isa<mlir::omp::ThreadprivateOp>(commonValue.getDefiningOp())) {
1537       // Generate ThreadprivateOp for a common block instead of its members and
1538       // only do it once for a common block.
1539       return;
1540     }
1541     // Generate ThreadprivateOp and rebind the common block.
1542     mlir::Value commonThreadprivateValue =
1543         firOpBuilder.create<mlir::omp::ThreadprivateOp>(
1544             currentLocation, commonValue.getType(), commonValue);
1545     converter.bindSymbol(*common, commonThreadprivateValue);
1546     // Generate the threadprivate value for the common block member.
1547     symThreadprivateValue =
1548         genCommonBlockMember(converter, sym, commonThreadprivateValue);
1549   } else {
1550     mlir::Value symValue = converter.getSymbolAddress(sym);
1551     symThreadprivateValue = firOpBuilder.create<mlir::omp::ThreadprivateOp>(
1552         currentLocation, symValue.getType(), symValue);
1553   }
1554 
1555   fir::ExtendedValue sexv = converter.getSymbolExtendedValue(sym);
1556   fir::ExtendedValue symThreadprivateExv =
1557       getExtendedValue(sexv, symThreadprivateValue);
1558   converter.bindSymbol(sym, symThreadprivateExv);
1559 }
1560 
genOpenMPDeclarativeConstruct(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OpenMPDeclarativeConstruct & ompDeclConstruct)1561 void Fortran::lower::genOpenMPDeclarativeConstruct(
1562     Fortran::lower::AbstractConverter &converter,
1563     Fortran::lower::pft::Evaluation &eval,
1564     const Fortran::parser::OpenMPDeclarativeConstruct &ompDeclConstruct) {
1565 
1566   std::visit(
1567       common::visitors{
1568           [&](const Fortran::parser::OpenMPDeclarativeAllocate
1569                   &declarativeAllocate) {
1570             TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate");
1571           },
1572           [&](const Fortran::parser::OpenMPDeclareReductionConstruct
1573                   &declareReductionConstruct) {
1574             TODO(converter.getCurrentLocation(),
1575                  "OpenMPDeclareReductionConstruct");
1576           },
1577           [&](const Fortran::parser::OpenMPDeclareSimdConstruct
1578                   &declareSimdConstruct) {
1579             TODO(converter.getCurrentLocation(), "OpenMPDeclareSimdConstruct");
1580           },
1581           [&](const Fortran::parser::OpenMPDeclareTargetConstruct
1582                   &declareTargetConstruct) {
1583             TODO(converter.getCurrentLocation(),
1584                  "OpenMPDeclareTargetConstruct");
1585           },
1586           [&](const Fortran::parser::OpenMPThreadprivate &threadprivate) {
1587             // The directive is lowered when instantiating the variable to
1588             // support the case of threadprivate variable declared in module.
1589           },
1590       },
1591       ompDeclConstruct.u);
1592 }
1593 
1594 // Generate an OpenMP reduction operation. This implementation finds the chain :
1595 // load reduction var -> reduction_operation -> store reduction var and replaces
1596 // it with the reduction operation.
1597 // TODO: Currently assumes it is an integer addition reduction. Generalize this
1598 // for various reduction operation types.
1599 // TODO: Generate the reduction operation during lowering instead of creating
1600 // and removing operations since this is not a robust approach. Also, removing
1601 // ops in the builder (instead of a rewriter) is probably not the best approach.
genOpenMPReduction(Fortran::lower::AbstractConverter & converter,const Fortran::parser::OmpClauseList & clauseList)1602 void Fortran::lower::genOpenMPReduction(
1603     Fortran::lower::AbstractConverter &converter,
1604     const Fortran::parser::OmpClauseList &clauseList) {
1605   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1606 
1607   for (const auto &clause : clauseList.v) {
1608     if (const auto &reductionClause =
1609             std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) {
1610       const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>(
1611           reductionClause->v.t)};
1612       const auto &objectList{
1613           std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)};
1614       if (auto reductionOp =
1615               std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
1616         const auto &intrinsicOp{
1617             std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
1618                 reductionOp->u)};
1619         if (intrinsicOp !=
1620             Fortran::parser::DefinedOperator::IntrinsicOperator::Add)
1621           continue;
1622         for (const auto &ompObject : objectList.v) {
1623           if (const auto *name{
1624                   Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
1625             if (const auto *symbol{name->symbol}) {
1626               mlir::Value symVal = converter.getSymbolAddress(*symbol);
1627               mlir::Type redType =
1628                   symVal.getType().cast<fir::ReferenceType>().getEleTy();
1629               if (!redType.isIntOrIndex())
1630                 continue;
1631               for (mlir::OpOperand &use1 : symVal.getUses()) {
1632                 if (auto load = mlir::dyn_cast<fir::LoadOp>(use1.getOwner())) {
1633                   mlir::Value loadVal = load.getRes();
1634                   for (mlir::OpOperand &use2 : loadVal.getUses()) {
1635                     if (auto add = mlir::dyn_cast<mlir::arith::AddIOp>(
1636                             use2.getOwner())) {
1637                       mlir::Value addRes = add.getResult();
1638                       for (mlir::OpOperand &use3 : addRes.getUses()) {
1639                         if (auto store =
1640                                 mlir::dyn_cast<fir::StoreOp>(use3.getOwner())) {
1641                           if (store.getMemref() == symVal) {
1642                             // Chain found! Now replace load->reduction->store
1643                             // with the OpenMP reduction operation.
1644                             mlir::OpBuilder::InsertPoint insertPtDel =
1645                                 firOpBuilder.saveInsertionPoint();
1646                             firOpBuilder.setInsertionPoint(add);
1647                             if (add.getLhs() == loadVal) {
1648                               firOpBuilder.create<mlir::omp::ReductionOp>(
1649                                   add.getLoc(), add.getRhs(), symVal);
1650                             } else {
1651                               firOpBuilder.create<mlir::omp::ReductionOp>(
1652                                   add.getLoc(), add.getLhs(), symVal);
1653                             }
1654                             store.erase();
1655                             add.erase();
1656                             load.erase();
1657                             firOpBuilder.restoreInsertionPoint(insertPtDel);
1658                           }
1659                         }
1660                       }
1661                     }
1662                   }
1663                 }
1664               }
1665             }
1666           }
1667         }
1668       }
1669     }
1670   }
1671 }
1672