1 //===-- OpenMP.cpp -- Open MP directive lowering --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "flang/Lower/OpenMP.h"
14 #include "flang/Common/idioms.h"
15 #include "flang/Lower/Bridge.h"
16 #include "flang/Lower/ConvertExpr.h"
17 #include "flang/Lower/PFTBuilder.h"
18 #include "flang/Lower/StatementContext.h"
19 #include "flang/Optimizer/Builder/BoxValue.h"
20 #include "flang/Optimizer/Builder/FIRBuilder.h"
21 #include "flang/Optimizer/Builder/Todo.h"
22 #include "flang/Parser/parse-tree.h"
23 #include "flang/Semantics/tools.h"
24 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
25 #include "mlir/Dialect/SCF/IR/SCF.h"
26 #include "llvm/Frontend/OpenMP/OMPConstants.h"
27
28 using namespace mlir;
29
getCollapseValue(const Fortran::parser::OmpClauseList & clauseList)30 int64_t Fortran::lower::getCollapseValue(
31 const Fortran::parser::OmpClauseList &clauseList) {
32 for (const auto &clause : clauseList.v) {
33 if (const auto &collapseClause =
34 std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u)) {
35 const auto *expr = Fortran::semantics::GetExpr(collapseClause->v);
36 return Fortran::evaluate::ToInt64(*expr).value();
37 }
38 }
39 return 1;
40 }
41
42 static const Fortran::parser::Name *
getDesignatorNameIfDataRef(const Fortran::parser::Designator & designator)43 getDesignatorNameIfDataRef(const Fortran::parser::Designator &designator) {
44 const auto *dataRef = std::get_if<Fortran::parser::DataRef>(&designator.u);
45 return dataRef ? std::get_if<Fortran::parser::Name>(&dataRef->u) : nullptr;
46 }
47
48 static Fortran::semantics::Symbol *
getOmpObjectSymbol(const Fortran::parser::OmpObject & ompObject)49 getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) {
50 Fortran::semantics::Symbol *sym = nullptr;
51 std::visit(Fortran::common::visitors{
52 [&](const Fortran::parser::Designator &designator) {
53 if (const Fortran::parser::Name *name =
54 getDesignatorNameIfDataRef(designator)) {
55 sym = name->symbol;
56 }
57 },
58 [&](const Fortran::parser::Name &name) { sym = name.symbol; }},
59 ompObject.u);
60 return sym;
61 }
62
63 template <typename T>
createPrivateVarSyms(Fortran::lower::AbstractConverter & converter,const T * clause,Block * lastPrivBlock=nullptr)64 static void createPrivateVarSyms(Fortran::lower::AbstractConverter &converter,
65 const T *clause,
66 Block *lastPrivBlock = nullptr) {
67 const Fortran::parser::OmpObjectList &ompObjectList = clause->v;
68 for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) {
69 Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
70 // Privatization for symbols which are pre-determined (like loop index
71 // variables) happen separately, for everything else privatize here.
72 if (sym->test(Fortran::semantics::Symbol::Flag::OmpPreDetermined))
73 continue;
74 bool success = converter.createHostAssociateVarClone(*sym);
75 (void)success;
76 assert(success && "Privatization failed due to existing binding");
77 if constexpr (std::is_same_v<T, Fortran::parser::OmpClause::Firstprivate>) {
78 converter.copyHostAssociateVar(*sym);
79 } else if constexpr (std::is_same_v<
80 T, Fortran::parser::OmpClause::Lastprivate>) {
81 converter.copyHostAssociateVar(*sym, lastPrivBlock);
82 }
83 }
84 }
85
86 template <typename Op>
privatizeVars(Op & op,Fortran::lower::AbstractConverter & converter,const Fortran::parser::OmpClauseList & opClauseList)87 static bool privatizeVars(Op &op, Fortran::lower::AbstractConverter &converter,
88 const Fortran::parser::OmpClauseList &opClauseList) {
89 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
90 auto insPt = firOpBuilder.saveInsertionPoint();
91 firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
92 bool hasFirstPrivateOp = false;
93 bool hasLastPrivateOp = false;
94 // We need just one ICmpOp for multiple LastPrivate clauses.
95 mlir::arith::CmpIOp cmpOp;
96
97 for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
98 if (const auto &privateClause =
99 std::get_if<Fortran::parser::OmpClause::Private>(&clause.u)) {
100 createPrivateVarSyms(converter, privateClause);
101 } else if (const auto &firstPrivateClause =
102 std::get_if<Fortran::parser::OmpClause::Firstprivate>(
103 &clause.u)) {
104 createPrivateVarSyms(converter, firstPrivateClause);
105 hasFirstPrivateOp = true;
106 } else if (const auto &lastPrivateClause =
107 std::get_if<Fortran::parser::OmpClause::Lastprivate>(
108 &clause.u)) {
109 // TODO: Add lastprivate support for sections construct, simd construct
110 if (std::is_same_v<Op, omp::WsLoopOp>) {
111 omp::WsLoopOp *wsLoopOp = dyn_cast<omp::WsLoopOp>(&op);
112 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
113 auto insPt = firOpBuilder.saveInsertionPoint();
114
115 // Our goal here is to introduce the following control flow
116 // just before exiting the worksharing loop.
117 // Say our wsloop is as follows:
118 //
119 // omp.wsloop {
120 // ...
121 // store
122 // omp.yield
123 // }
124 //
125 // We want to convert it to the following:
126 //
127 // omp.wsloop {
128 // ...
129 // store
130 // %cmp = llvm.icmp "eq" %iv %ub
131 // scf.if %cmp {
132 // ^%lpv_update_blk:
133 // }
134 // omp.yield
135 // }
136
137 Operation *lastOper = wsLoopOp->region().back().getTerminator();
138
139 firOpBuilder.setInsertionPoint(lastOper);
140
141 // TODO: The following will not work when there is collapse present.
142 // Have to modify this in future.
143 for (const Fortran::parser::OmpClause &clause : opClauseList.v)
144 if (const auto &collapseClause =
145 std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u))
146 TODO(converter.getCurrentLocation(),
147 "Collapse clause with lastprivate");
148 // Only generate the compare once in presence of multiple LastPrivate
149 // clauses
150 if (!hasLastPrivateOp) {
151 cmpOp = firOpBuilder.create<mlir::arith::CmpIOp>(
152 wsLoopOp->getLoc(), mlir::arith::CmpIPredicate::eq,
153 wsLoopOp->getRegion().front().getArguments()[0],
154 wsLoopOp->upperBound()[0]);
155 }
156 mlir::scf::IfOp ifOp = firOpBuilder.create<mlir::scf::IfOp>(
157 wsLoopOp->getLoc(), cmpOp, /*else*/ false);
158
159 firOpBuilder.restoreInsertionPoint(insPt);
160 createPrivateVarSyms(converter, lastPrivateClause,
161 &(ifOp.getThenRegion().front()));
162 } else {
163 TODO(converter.getCurrentLocation(),
164 "lastprivate clause in constructs other than work-share loop");
165 }
166 hasLastPrivateOp = true;
167 }
168 }
169 if (hasFirstPrivateOp)
170 firOpBuilder.create<mlir::omp::BarrierOp>(converter.getCurrentLocation());
171 firOpBuilder.restoreInsertionPoint(insPt);
172 return hasLastPrivateOp;
173 }
174
175 /// The COMMON block is a global structure. \p commonValue is the base address
176 /// of the the COMMON block. As the offset from the symbol \p sym, generate the
177 /// COMMON block member value (commonValue + offset) for the symbol.
178 /// FIXME: Share the code with `instantiateCommon` in ConvertVariable.cpp.
179 static mlir::Value
genCommonBlockMember(Fortran::lower::AbstractConverter & converter,const Fortran::semantics::Symbol & sym,mlir::Value commonValue)180 genCommonBlockMember(Fortran::lower::AbstractConverter &converter,
181 const Fortran::semantics::Symbol &sym,
182 mlir::Value commonValue) {
183 auto &firOpBuilder = converter.getFirOpBuilder();
184 mlir::Location currentLocation = converter.getCurrentLocation();
185 mlir::IntegerType i8Ty = firOpBuilder.getIntegerType(8);
186 mlir::Type i8Ptr = firOpBuilder.getRefType(i8Ty);
187 mlir::Type seqTy = firOpBuilder.getRefType(firOpBuilder.getVarLenSeqTy(i8Ty));
188 mlir::Value base =
189 firOpBuilder.createConvert(currentLocation, seqTy, commonValue);
190 std::size_t byteOffset = sym.GetUltimate().offset();
191 mlir::Value offs = firOpBuilder.createIntegerConstant(
192 currentLocation, firOpBuilder.getIndexType(), byteOffset);
193 mlir::Value varAddr = firOpBuilder.create<fir::CoordinateOp>(
194 currentLocation, i8Ptr, base, mlir::ValueRange{offs});
195 mlir::Type symType = converter.genType(sym);
196 return firOpBuilder.createConvert(currentLocation,
197 firOpBuilder.getRefType(symType), varAddr);
198 }
199
200 // Get the extended value for \p val by extracting additional variable
201 // information from \p base.
getExtendedValue(fir::ExtendedValue base,mlir::Value val)202 static fir::ExtendedValue getExtendedValue(fir::ExtendedValue base,
203 mlir::Value val) {
204 return base.match(
205 [&](const fir::MutableBoxValue &box) -> fir::ExtendedValue {
206 return fir::MutableBoxValue(val, box.nonDeferredLenParams(), {});
207 },
208 [&](const auto &) -> fir::ExtendedValue {
209 return fir::substBase(base, val);
210 });
211 }
212
threadPrivatizeVars(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval)213 static void threadPrivatizeVars(Fortran::lower::AbstractConverter &converter,
214 Fortran::lower::pft::Evaluation &eval) {
215 auto &firOpBuilder = converter.getFirOpBuilder();
216 mlir::Location currentLocation = converter.getCurrentLocation();
217 auto insPt = firOpBuilder.saveInsertionPoint();
218 firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
219
220 // Get the original ThreadprivateOp corresponding to the symbol and use the
221 // symbol value from that opeartion to create one ThreadprivateOp copy
222 // operation inside the parallel region.
223 auto genThreadprivateOp = [&](Fortran::lower::SymbolRef sym) -> mlir::Value {
224 mlir::Value symOriThreadprivateValue = converter.getSymbolAddress(sym);
225 mlir::Operation *op = symOriThreadprivateValue.getDefiningOp();
226 assert(mlir::isa<mlir::omp::ThreadprivateOp>(op) &&
227 "The threadprivate operation not created");
228 mlir::Value symValue =
229 mlir::dyn_cast<mlir::omp::ThreadprivateOp>(op).sym_addr();
230 return firOpBuilder.create<mlir::omp::ThreadprivateOp>(
231 currentLocation, symValue.getType(), symValue);
232 };
233
234 llvm::SetVector<const Fortran::semantics::Symbol *> threadprivateSyms;
235 converter.collectSymbolSet(eval, threadprivateSyms,
236 Fortran::semantics::Symbol::Flag::OmpThreadprivate,
237 /*isUltimateSymbol=*/false);
238 std::set<Fortran::semantics::SourceName> threadprivateSymNames;
239
240 // For a COMMON block, the ThreadprivateOp is generated for itself instead of
241 // its members, so only bind the value of the new copied ThreadprivateOp
242 // inside the parallel region to the common block symbol only once for
243 // multiple members in one COMMON block.
244 llvm::SetVector<const Fortran::semantics::Symbol *> commonSyms;
245 for (std::size_t i = 0; i < threadprivateSyms.size(); i++) {
246 auto sym = threadprivateSyms[i];
247 mlir::Value symThreadprivateValue;
248 // The variable may be used more than once, and each reference has one
249 // symbol with the same name. Only do once for references of one variable.
250 if (threadprivateSymNames.find(sym->name()) != threadprivateSymNames.end())
251 continue;
252 threadprivateSymNames.insert(sym->name());
253 if (const Fortran::semantics::Symbol *common =
254 Fortran::semantics::FindCommonBlockContaining(sym->GetUltimate())) {
255 mlir::Value commonThreadprivateValue;
256 if (commonSyms.contains(common)) {
257 commonThreadprivateValue = converter.getSymbolAddress(*common);
258 } else {
259 commonThreadprivateValue = genThreadprivateOp(*common);
260 converter.bindSymbol(*common, commonThreadprivateValue);
261 commonSyms.insert(common);
262 }
263 symThreadprivateValue =
264 genCommonBlockMember(converter, *sym, commonThreadprivateValue);
265 } else {
266 symThreadprivateValue = genThreadprivateOp(*sym);
267 }
268
269 fir::ExtendedValue sexv = converter.getSymbolExtendedValue(*sym);
270 fir::ExtendedValue symThreadprivateExv =
271 getExtendedValue(sexv, symThreadprivateValue);
272 converter.bindSymbol(*sym, symThreadprivateExv);
273 }
274
275 firOpBuilder.restoreInsertionPoint(insPt);
276 }
277
278 static void
genCopyinClause(Fortran::lower::AbstractConverter & converter,const Fortran::parser::OmpClauseList & opClauseList)279 genCopyinClause(Fortran::lower::AbstractConverter &converter,
280 const Fortran::parser::OmpClauseList &opClauseList) {
281 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
282 mlir::OpBuilder::InsertPoint insPt = firOpBuilder.saveInsertionPoint();
283 firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
284 bool hasCopyin = false;
285 for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
286 if (const auto ©inClause =
287 std::get_if<Fortran::parser::OmpClause::Copyin>(&clause.u)) {
288 hasCopyin = true;
289 const Fortran::parser::OmpObjectList &ompObjectList = copyinClause->v;
290 for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) {
291 Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
292 if (sym->has<Fortran::semantics::CommonBlockDetails>())
293 TODO(converter.getCurrentLocation(), "common block in Copyin clause");
294 if (Fortran::semantics::IsAllocatableOrPointer(sym->GetUltimate()))
295 TODO(converter.getCurrentLocation(),
296 "pointer or allocatable variables in Copyin clause");
297 assert(sym->has<Fortran::semantics::HostAssocDetails>() &&
298 "No host-association found");
299 converter.copyHostAssociateVar(*sym);
300 }
301 }
302 }
303 // [OMP 5.0, 2.19.6.1] The copy is done after the team is formed and prior to
304 // the execution of the associated structured block. Emit implicit barrier to
305 // synchronize threads and avoid data races on propagation master's thread
306 // values of threadprivate variables to local instances of that variables of
307 // all other implicit threads.
308 if (hasCopyin)
309 firOpBuilder.create<mlir::omp::BarrierOp>(converter.getCurrentLocation());
310 firOpBuilder.restoreInsertionPoint(insPt);
311 }
312
genObjectList(const Fortran::parser::OmpObjectList & objectList,Fortran::lower::AbstractConverter & converter,llvm::SmallVectorImpl<Value> & operands)313 static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
314 Fortran::lower::AbstractConverter &converter,
315 llvm::SmallVectorImpl<Value> &operands) {
316 auto addOperands = [&](Fortran::lower::SymbolRef sym) {
317 const mlir::Value variable = converter.getSymbolAddress(sym);
318 if (variable) {
319 operands.push_back(variable);
320 } else {
321 if (const auto *details =
322 sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
323 operands.push_back(converter.getSymbolAddress(details->symbol()));
324 converter.copySymbolBinding(details->symbol(), sym);
325 }
326 }
327 };
328 for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
329 Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject);
330 addOperands(*sym);
331 }
332 }
333
getLoopVarType(Fortran::lower::AbstractConverter & converter,std::size_t loopVarTypeSize)334 static mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
335 std::size_t loopVarTypeSize) {
336 // OpenMP runtime requires 32-bit or 64-bit loop variables.
337 loopVarTypeSize = loopVarTypeSize * 8;
338 if (loopVarTypeSize < 32) {
339 loopVarTypeSize = 32;
340 } else if (loopVarTypeSize > 64) {
341 loopVarTypeSize = 64;
342 mlir::emitWarning(converter.getCurrentLocation(),
343 "OpenMP loop iteration variable cannot have more than 64 "
344 "bits size and will be narrowed into 64 bits.");
345 }
346 assert((loopVarTypeSize == 32 || loopVarTypeSize == 64) &&
347 "OpenMP loop iteration variable size must be transformed into 32-bit "
348 "or 64-bit");
349 return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize);
350 }
351
352 /// Create empty blocks for the current region.
353 /// These blocks replace blocks parented to an enclosing region.
createEmptyRegionBlocks(fir::FirOpBuilder & firOpBuilder,std::list<Fortran::lower::pft::Evaluation> & evaluationList)354 void createEmptyRegionBlocks(
355 fir::FirOpBuilder &firOpBuilder,
356 std::list<Fortran::lower::pft::Evaluation> &evaluationList) {
357 auto *region = &firOpBuilder.getRegion();
358 for (auto &eval : evaluationList) {
359 if (eval.block) {
360 if (eval.block->empty()) {
361 eval.block->erase();
362 eval.block = firOpBuilder.createBlock(region);
363 } else {
364 [[maybe_unused]] auto &terminatorOp = eval.block->back();
365 assert((mlir::isa<mlir::omp::TerminatorOp>(terminatorOp) ||
366 mlir::isa<mlir::omp::YieldOp>(terminatorOp)) &&
367 "expected terminator op");
368 }
369 }
370 if (!eval.isDirective() && eval.hasNestedEvaluations())
371 createEmptyRegionBlocks(firOpBuilder, eval.getNestedEvaluations());
372 }
373 }
374
resetBeforeTerminator(fir::FirOpBuilder & firOpBuilder,mlir::Operation * storeOp,mlir::Block & block)375 void resetBeforeTerminator(fir::FirOpBuilder &firOpBuilder,
376 mlir::Operation *storeOp, mlir::Block &block) {
377 if (storeOp)
378 firOpBuilder.setInsertionPointAfter(storeOp);
379 else
380 firOpBuilder.setInsertionPointToStart(&block);
381 }
382
383 /// Create the body (block) for an OpenMP Operation.
384 ///
385 /// \param [in] op - the operation the body belongs to.
386 /// \param [inout] converter - converter to use for the clauses.
387 /// \param [in] loc - location in source code.
388 /// \param [in] eval - current PFT node/evaluation.
389 /// \oaran [in] clauses - list of clauses to process.
390 /// \param [in] args - block arguments (induction variable[s]) for the
391 //// region.
392 /// \param [in] outerCombined - is this an outer operation - prevents
393 /// privatization.
394 template <typename Op>
395 static void
createBodyOfOp(Op & op,Fortran::lower::AbstractConverter & converter,mlir::Location & loc,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OmpClauseList * clauses=nullptr,const SmallVector<const Fortran::semantics::Symbol * > & args={},bool outerCombined=false)396 createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter,
397 mlir::Location &loc, Fortran::lower::pft::Evaluation &eval,
398 const Fortran::parser::OmpClauseList *clauses = nullptr,
399 const SmallVector<const Fortran::semantics::Symbol *> &args = {},
400 bool outerCombined = false) {
401 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
402 // If an argument for the region is provided then create the block with that
403 // argument. Also update the symbol's address with the mlir argument value.
404 // e.g. For loops the argument is the induction variable. And all further
405 // uses of the induction variable should use this mlir value.
406 mlir::Operation *storeOp = nullptr;
407 if (args.size()) {
408 std::size_t loopVarTypeSize = 0;
409 for (const Fortran::semantics::Symbol *arg : args)
410 loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
411 mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
412 SmallVector<Type> tiv;
413 SmallVector<Location> locs;
414 for (int i = 0; i < (int)args.size(); i++) {
415 tiv.push_back(loopVarType);
416 locs.push_back(loc);
417 }
418 firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs);
419 int argIndex = 0;
420 // The argument is not currently in memory, so make a temporary for the
421 // argument, and store it there, then bind that location to the argument.
422 for (const Fortran::semantics::Symbol *arg : args) {
423 mlir::Value val =
424 fir::getBase(op.getRegion().front().getArgument(argIndex));
425 mlir::Value temp = firOpBuilder.createTemporary(
426 loc, loopVarType,
427 llvm::ArrayRef<mlir::NamedAttribute>{
428 Fortran::lower::getAdaptToByRefAttr(firOpBuilder)});
429 storeOp = firOpBuilder.create<fir::StoreOp>(loc, val, temp);
430 converter.bindSymbol(*arg, temp);
431 argIndex++;
432 }
433 } else {
434 firOpBuilder.createBlock(&op.getRegion());
435 }
436 // Set the insert for the terminator operation to go at the end of the
437 // block - this is either empty or the block with the stores above,
438 // the end of the block works for both.
439 mlir::Block &block = op.getRegion().back();
440 firOpBuilder.setInsertionPointToEnd(&block);
441
442 // If it is an unstructured region and is not the outer region of a combined
443 // construct, create empty blocks for all evaluations.
444 if (eval.lowerAsUnstructured() && !outerCombined)
445 createEmptyRegionBlocks(firOpBuilder, eval.getNestedEvaluations());
446
447 // Insert the terminator.
448 if constexpr (std::is_same_v<Op, omp::WsLoopOp> ||
449 std::is_same_v<Op, omp::SimdLoopOp>) {
450 mlir::ValueRange results;
451 firOpBuilder.create<mlir::omp::YieldOp>(loc, results);
452 } else {
453 firOpBuilder.create<mlir::omp::TerminatorOp>(loc);
454 }
455
456 // Reset the insert point to before the terminator.
457 resetBeforeTerminator(firOpBuilder, storeOp, block);
458
459 // Handle privatization. Do not privatize if this is the outer operation.
460 if (clauses && !outerCombined) {
461 bool lastPrivateOp = privatizeVars(op, converter, *clauses);
462 // LastPrivatization, due to introduction of
463 // new control flow, changes the insertion point,
464 // thus restore it.
465 // TODO: Clean up later a bit to avoid this many sets and resets.
466 if (lastPrivateOp)
467 resetBeforeTerminator(firOpBuilder, storeOp, block);
468 }
469
470 if constexpr (std::is_same_v<Op, omp::ParallelOp>) {
471 threadPrivatizeVars(converter, eval);
472 if (clauses)
473 genCopyinClause(converter, *clauses);
474 }
475 }
476
genOMP(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OpenMPSimpleStandaloneConstruct & simpleStandaloneConstruct)477 static void genOMP(Fortran::lower::AbstractConverter &converter,
478 Fortran::lower::pft::Evaluation &eval,
479 const Fortran::parser::OpenMPSimpleStandaloneConstruct
480 &simpleStandaloneConstruct) {
481 const auto &directive =
482 std::get<Fortran::parser::OmpSimpleStandaloneDirective>(
483 simpleStandaloneConstruct.t);
484 switch (directive.v) {
485 default:
486 break;
487 case llvm::omp::Directive::OMPD_barrier:
488 converter.getFirOpBuilder().create<mlir::omp::BarrierOp>(
489 converter.getCurrentLocation());
490 break;
491 case llvm::omp::Directive::OMPD_taskwait:
492 converter.getFirOpBuilder().create<mlir::omp::TaskwaitOp>(
493 converter.getCurrentLocation());
494 break;
495 case llvm::omp::Directive::OMPD_taskyield:
496 converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>(
497 converter.getCurrentLocation());
498 break;
499 case llvm::omp::Directive::OMPD_target_enter_data:
500 TODO(converter.getCurrentLocation(), "OMPD_target_enter_data");
501 case llvm::omp::Directive::OMPD_target_exit_data:
502 TODO(converter.getCurrentLocation(), "OMPD_target_exit_data");
503 case llvm::omp::Directive::OMPD_target_update:
504 TODO(converter.getCurrentLocation(), "OMPD_target_update");
505 case llvm::omp::Directive::OMPD_ordered:
506 TODO(converter.getCurrentLocation(), "OMPD_ordered");
507 }
508 }
509
510 static void
genAllocateClause(Fortran::lower::AbstractConverter & converter,const Fortran::parser::OmpAllocateClause & ompAllocateClause,SmallVector<Value> & allocatorOperands,SmallVector<Value> & allocateOperands)511 genAllocateClause(Fortran::lower::AbstractConverter &converter,
512 const Fortran::parser::OmpAllocateClause &ompAllocateClause,
513 SmallVector<Value> &allocatorOperands,
514 SmallVector<Value> &allocateOperands) {
515 auto &firOpBuilder = converter.getFirOpBuilder();
516 auto currentLocation = converter.getCurrentLocation();
517 Fortran::lower::StatementContext stmtCtx;
518
519 mlir::Value allocatorOperand;
520 const Fortran::parser::OmpObjectList &ompObjectList =
521 std::get<Fortran::parser::OmpObjectList>(ompAllocateClause.t);
522 const auto &allocatorValue =
523 std::get<std::optional<Fortran::parser::OmpAllocateClause::Allocator>>(
524 ompAllocateClause.t);
525 // Check if allocate clause has allocator specified. If so, add it
526 // to list of allocators, otherwise, add default allocator to
527 // list of allocators.
528 if (allocatorValue) {
529 allocatorOperand = fir::getBase(converter.genExprValue(
530 *Fortran::semantics::GetExpr(allocatorValue->v), stmtCtx));
531 allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(),
532 allocatorOperand);
533 } else {
534 allocatorOperand = firOpBuilder.createIntegerConstant(
535 currentLocation, firOpBuilder.getI32Type(), 1);
536 allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(),
537 allocatorOperand);
538 }
539 genObjectList(ompObjectList, converter, allocateOperands);
540 }
541
542 static void
genOMP(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OpenMPStandaloneConstruct & standaloneConstruct)543 genOMP(Fortran::lower::AbstractConverter &converter,
544 Fortran::lower::pft::Evaluation &eval,
545 const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) {
546 std::visit(
547 Fortran::common::visitors{
548 [&](const Fortran::parser::OpenMPSimpleStandaloneConstruct
549 &simpleStandaloneConstruct) {
550 genOMP(converter, eval, simpleStandaloneConstruct);
551 },
552 [&](const Fortran::parser::OpenMPFlushConstruct &flushConstruct) {
553 SmallVector<Value, 4> operandRange;
554 if (const auto &ompObjectList =
555 std::get<std::optional<Fortran::parser::OmpObjectList>>(
556 flushConstruct.t))
557 genObjectList(*ompObjectList, converter, operandRange);
558 const auto &memOrderClause = std::get<std::optional<
559 std::list<Fortran::parser::OmpMemoryOrderClause>>>(
560 flushConstruct.t);
561 if (memOrderClause.has_value() && memOrderClause->size() > 0)
562 TODO(converter.getCurrentLocation(),
563 "Handle OmpMemoryOrderClause");
564 converter.getFirOpBuilder().create<mlir::omp::FlushOp>(
565 converter.getCurrentLocation(), operandRange);
566 },
567 [&](const Fortran::parser::OpenMPCancelConstruct &cancelConstruct) {
568 TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct");
569 },
570 [&](const Fortran::parser::OpenMPCancellationPointConstruct
571 &cancellationPointConstruct) {
572 TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct");
573 },
574 },
575 standaloneConstruct.u);
576 }
577
genProcBindKindAttr(fir::FirOpBuilder & firOpBuilder,const Fortran::parser::OmpClause::ProcBind * procBindClause)578 static omp::ClauseProcBindKindAttr genProcBindKindAttr(
579 fir::FirOpBuilder &firOpBuilder,
580 const Fortran::parser::OmpClause::ProcBind *procBindClause) {
581 omp::ClauseProcBindKind pbKind;
582 switch (procBindClause->v.v) {
583 case Fortran::parser::OmpProcBindClause::Type::Master:
584 pbKind = omp::ClauseProcBindKind::Master;
585 break;
586 case Fortran::parser::OmpProcBindClause::Type::Close:
587 pbKind = omp::ClauseProcBindKind::Close;
588 break;
589 case Fortran::parser::OmpProcBindClause::Type::Spread:
590 pbKind = omp::ClauseProcBindKind::Spread;
591 break;
592 case Fortran::parser::OmpProcBindClause::Type::Primary:
593 pbKind = omp::ClauseProcBindKind::Primary;
594 break;
595 }
596 return omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(), pbKind);
597 }
598
599 static mlir::Value
getIfClauseOperand(Fortran::lower::AbstractConverter & converter,Fortran::lower::StatementContext & stmtCtx,const Fortran::parser::OmpClause::If * ifClause)600 getIfClauseOperand(Fortran::lower::AbstractConverter &converter,
601 Fortran::lower::StatementContext &stmtCtx,
602 const Fortran::parser::OmpClause::If *ifClause) {
603 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
604 mlir::Location currentLocation = converter.getCurrentLocation();
605 auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
606 mlir::Value ifVal = fir::getBase(
607 converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
608 return firOpBuilder.createConvert(currentLocation, firOpBuilder.getI1Type(),
609 ifVal);
610 }
611
612 /* When parallel is used in a combined construct, then use this function to
613 * create the parallel operation. It handles the parallel specific clauses
614 * and leaves the rest for handling at the inner operations.
615 * TODO: Refactor clause handling
616 */
617 template <typename Directive>
618 static void
createCombinedParallelOp(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Directive & directive)619 createCombinedParallelOp(Fortran::lower::AbstractConverter &converter,
620 Fortran::lower::pft::Evaluation &eval,
621 const Directive &directive) {
622 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
623 mlir::Location currentLocation = converter.getCurrentLocation();
624 Fortran::lower::StatementContext stmtCtx;
625 llvm::ArrayRef<mlir::Type> argTy;
626 mlir::Value ifClauseOperand, numThreadsClauseOperand;
627 SmallVector<Value> allocatorOperands, allocateOperands;
628 mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
629 const auto &opClauseList =
630 std::get<Fortran::parser::OmpClauseList>(directive.t);
631 // TODO: Handle the following clauses
632 // 1. default
633 // Note: rest of the clauses are handled when the inner operation is created
634 for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
635 if (const auto &ifClause =
636 std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
637 ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause);
638 } else if (const auto &numThreadsClause =
639 std::get_if<Fortran::parser::OmpClause::NumThreads>(
640 &clause.u)) {
641 numThreadsClauseOperand = fir::getBase(converter.genExprValue(
642 *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx));
643 } else if (const auto &procBindClause =
644 std::get_if<Fortran::parser::OmpClause::ProcBind>(
645 &clause.u)) {
646 procBindKindAttr = genProcBindKindAttr(firOpBuilder, procBindClause);
647 }
648 }
649 // Create and insert the operation.
650 auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
651 currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
652 allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(),
653 /*reductions=*/nullptr, procBindKindAttr);
654
655 createBodyOfOp<omp::ParallelOp>(parallelOp, converter, currentLocation, eval,
656 &opClauseList, /*iv=*/{},
657 /*isCombined=*/true);
658 }
659
660 static void
genOMP(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OpenMPBlockConstruct & blockConstruct)661 genOMP(Fortran::lower::AbstractConverter &converter,
662 Fortran::lower::pft::Evaluation &eval,
663 const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
664 const auto &beginBlockDirective =
665 std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t);
666 const auto &blockDirective =
667 std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t);
668 const auto &endBlockDirective =
669 std::get<Fortran::parser::OmpEndBlockDirective>(blockConstruct.t);
670 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
671 mlir::Location currentLocation = converter.getCurrentLocation();
672
673 Fortran::lower::StatementContext stmtCtx;
674 llvm::ArrayRef<mlir::Type> argTy;
675 mlir::Value ifClauseOperand, numThreadsClauseOperand, finalClauseOperand,
676 priorityClauseOperand;
677 mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
678 SmallVector<Value> allocateOperands, allocatorOperands;
679 mlir::UnitAttr nowaitAttr, untiedAttr, mergeableAttr;
680
681 const auto &opClauseList =
682 std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
683 for (const auto &clause : opClauseList.v) {
684 if (const auto &ifClause =
685 std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
686 ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause);
687 } else if (const auto &numThreadsClause =
688 std::get_if<Fortran::parser::OmpClause::NumThreads>(
689 &clause.u)) {
690 // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`.
691 numThreadsClauseOperand = fir::getBase(converter.genExprValue(
692 *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx));
693 } else if (const auto &procBindClause =
694 std::get_if<Fortran::parser::OmpClause::ProcBind>(
695 &clause.u)) {
696 procBindKindAttr = genProcBindKindAttr(firOpBuilder, procBindClause);
697 } else if (const auto &allocateClause =
698 std::get_if<Fortran::parser::OmpClause::Allocate>(
699 &clause.u)) {
700 genAllocateClause(converter, allocateClause->v, allocatorOperands,
701 allocateOperands);
702 } else if (std::get_if<Fortran::parser::OmpClause::Private>(&clause.u) ||
703 std::get_if<Fortran::parser::OmpClause::Firstprivate>(
704 &clause.u) ||
705 std::get_if<Fortran::parser::OmpClause::Copyin>(&clause.u)) {
706 // Privatisation and copyin clauses are handled elsewhere.
707 continue;
708 } else if (std::get_if<Fortran::parser::OmpClause::Shared>(&clause.u)) {
709 // Shared is the default behavior in the IR, so no handling is required.
710 continue;
711 } else if (const auto &defaultClause =
712 std::get_if<Fortran::parser::OmpClause::Default>(
713 &clause.u)) {
714 if ((defaultClause->v.v ==
715 Fortran::parser::OmpDefaultClause::Type::Shared) ||
716 (defaultClause->v.v ==
717 Fortran::parser::OmpDefaultClause::Type::None)) {
718 // Default clause with shared or none do not require any handling since
719 // Shared is the default behavior in the IR and None is only required
720 // for semantic checks.
721 continue;
722 }
723 } else if (std::get_if<Fortran::parser::OmpClause::Threads>(&clause.u)) {
724 // Nothing needs to be done for threads clause.
725 continue;
726 } else if (const auto &finalClause =
727 std::get_if<Fortran::parser::OmpClause::Final>(&clause.u)) {
728 mlir::Value finalVal = fir::getBase(converter.genExprValue(
729 *Fortran::semantics::GetExpr(finalClause->v), stmtCtx));
730 finalClauseOperand = firOpBuilder.createConvert(
731 currentLocation, firOpBuilder.getI1Type(), finalVal);
732 } else if (std::get_if<Fortran::parser::OmpClause::Untied>(&clause.u)) {
733 untiedAttr = firOpBuilder.getUnitAttr();
734 } else if (std::get_if<Fortran::parser::OmpClause::Mergeable>(&clause.u)) {
735 mergeableAttr = firOpBuilder.getUnitAttr();
736 } else if (const auto &priorityClause =
737 std::get_if<Fortran::parser::OmpClause::Priority>(
738 &clause.u)) {
739 priorityClauseOperand = fir::getBase(converter.genExprValue(
740 *Fortran::semantics::GetExpr(priorityClause->v), stmtCtx));
741 } else {
742 TODO(currentLocation, "OpenMP Block construct clauses");
743 }
744 }
745
746 for (const auto &clause :
747 std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t).v) {
748 if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u))
749 nowaitAttr = firOpBuilder.getUnitAttr();
750 }
751
752 if (blockDirective.v == llvm::omp::OMPD_parallel) {
753 // Create and insert the operation.
754 auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
755 currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
756 allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(),
757 /*reductions=*/nullptr, procBindKindAttr);
758 createBodyOfOp<omp::ParallelOp>(parallelOp, converter, currentLocation,
759 eval, &opClauseList);
760 } else if (blockDirective.v == llvm::omp::OMPD_master) {
761 auto masterOp =
762 firOpBuilder.create<mlir::omp::MasterOp>(currentLocation, argTy);
763 createBodyOfOp<omp::MasterOp>(masterOp, converter, currentLocation, eval);
764 } else if (blockDirective.v == llvm::omp::OMPD_single) {
765 auto singleOp = firOpBuilder.create<mlir::omp::SingleOp>(
766 currentLocation, allocateOperands, allocatorOperands, nowaitAttr);
767 createBodyOfOp<omp::SingleOp>(singleOp, converter, currentLocation, eval);
768 } else if (blockDirective.v == llvm::omp::OMPD_ordered) {
769 auto orderedOp = firOpBuilder.create<mlir::omp::OrderedRegionOp>(
770 currentLocation, /*simd=*/nullptr);
771 createBodyOfOp<omp::OrderedRegionOp>(orderedOp, converter, currentLocation,
772 eval);
773 } else if (blockDirective.v == llvm::omp::OMPD_task) {
774 auto taskOp = firOpBuilder.create<mlir::omp::TaskOp>(
775 currentLocation, ifClauseOperand, finalClauseOperand, untiedAttr,
776 mergeableAttr, /*in_reduction_vars=*/ValueRange(),
777 /*in_reductions=*/nullptr, priorityClauseOperand, allocateOperands,
778 allocatorOperands);
779 createBodyOfOp(taskOp, converter, currentLocation, eval, &opClauseList);
780 } else {
781 TODO(converter.getCurrentLocation(), "Unhandled block directive");
782 }
783 }
784
785 /// Creates an OpenMP reduction declaration and inserts it into the provided
786 /// symbol table. The declaration has a constant initializer with the neutral
787 /// value `initValue`, and the reduction combiner carried over from `reduce`.
788 /// TODO: Generalize this for non-integer types, add atomic region.
createReductionDecl(fir::FirOpBuilder & builder,llvm::StringRef name,mlir::Type type,mlir::Location loc)789 static omp::ReductionDeclareOp createReductionDecl(fir::FirOpBuilder &builder,
790 llvm::StringRef name,
791 mlir::Type type,
792 mlir::Location loc) {
793 OpBuilder::InsertionGuard guard(builder);
794 mlir::ModuleOp module = builder.getModule();
795 mlir::OpBuilder modBuilder(module.getBodyRegion());
796 auto decl = module.lookupSymbol<mlir::omp::ReductionDeclareOp>(name);
797 if (!decl)
798 decl = modBuilder.create<omp::ReductionDeclareOp>(loc, name, type);
799 else
800 return decl;
801
802 builder.createBlock(&decl.initializerRegion(), decl.initializerRegion().end(),
803 {type}, {loc});
804 builder.setInsertionPointToEnd(&decl.initializerRegion().back());
805 Value init = builder.create<mlir::arith::ConstantOp>(
806 loc, type, builder.getIntegerAttr(type, 0));
807 builder.create<omp::YieldOp>(loc, init);
808
809 builder.createBlock(&decl.reductionRegion(), decl.reductionRegion().end(),
810 {type, type}, {loc, loc});
811 builder.setInsertionPointToEnd(&decl.reductionRegion().back());
812 mlir::Value op1 = decl.reductionRegion().front().getArgument(0);
813 mlir::Value op2 = decl.reductionRegion().front().getArgument(1);
814 Value addRes = builder.create<mlir::arith::AddIOp>(loc, op1, op2);
815 builder.create<omp::YieldOp>(loc, addRes);
816 return decl;
817 }
818
819 static mlir::omp::ScheduleModifier
translateModifier(const Fortran::parser::OmpScheduleModifierType & m)820 translateModifier(const Fortran::parser::OmpScheduleModifierType &m) {
821 switch (m.v) {
822 case Fortran::parser::OmpScheduleModifierType::ModType::Monotonic:
823 return mlir::omp::ScheduleModifier::monotonic;
824 case Fortran::parser::OmpScheduleModifierType::ModType::Nonmonotonic:
825 return mlir::omp::ScheduleModifier::nonmonotonic;
826 case Fortran::parser::OmpScheduleModifierType::ModType::Simd:
827 return mlir::omp::ScheduleModifier::simd;
828 }
829 return mlir::omp::ScheduleModifier::none;
830 }
831
832 static mlir::omp::ScheduleModifier
getScheduleModifier(const Fortran::parser::OmpScheduleClause & x)833 getScheduleModifier(const Fortran::parser::OmpScheduleClause &x) {
834 const auto &modifier =
835 std::get<std::optional<Fortran::parser::OmpScheduleModifier>>(x.t);
836 // The input may have the modifier any order, so we look for one that isn't
837 // SIMD. If modifier is not set at all, fall down to the bottom and return
838 // "none".
839 if (modifier) {
840 const auto &modType1 =
841 std::get<Fortran::parser::OmpScheduleModifier::Modifier1>(modifier->t);
842 if (modType1.v.v ==
843 Fortran::parser::OmpScheduleModifierType::ModType::Simd) {
844 const auto &modType2 = std::get<
845 std::optional<Fortran::parser::OmpScheduleModifier::Modifier2>>(
846 modifier->t);
847 if (modType2 &&
848 modType2->v.v !=
849 Fortran::parser::OmpScheduleModifierType::ModType::Simd)
850 return translateModifier(modType2->v);
851
852 return mlir::omp::ScheduleModifier::none;
853 }
854
855 return translateModifier(modType1.v);
856 }
857 return mlir::omp::ScheduleModifier::none;
858 }
859
860 static mlir::omp::ScheduleModifier
getSIMDModifier(const Fortran::parser::OmpScheduleClause & x)861 getSIMDModifier(const Fortran::parser::OmpScheduleClause &x) {
862 const auto &modifier =
863 std::get<std::optional<Fortran::parser::OmpScheduleModifier>>(x.t);
864 // Either of the two possible modifiers in the input can be the SIMD modifier,
865 // so look in either one, and return simd if we find one. Not found = return
866 // "none".
867 if (modifier) {
868 const auto &modType1 =
869 std::get<Fortran::parser::OmpScheduleModifier::Modifier1>(modifier->t);
870 if (modType1.v.v == Fortran::parser::OmpScheduleModifierType::ModType::Simd)
871 return mlir::omp::ScheduleModifier::simd;
872
873 const auto &modType2 = std::get<
874 std::optional<Fortran::parser::OmpScheduleModifier::Modifier2>>(
875 modifier->t);
876 if (modType2 && modType2->v.v ==
877 Fortran::parser::OmpScheduleModifierType::ModType::Simd)
878 return mlir::omp::ScheduleModifier::simd;
879 }
880 return mlir::omp::ScheduleModifier::none;
881 }
882
getReductionName(Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,mlir::Type ty)883 static std::string getReductionName(
884 Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
885 mlir::Type ty) {
886 std::string reductionName;
887 if (intrinsicOp == Fortran::parser::DefinedOperator::IntrinsicOperator::Add)
888 reductionName = "add_reduction";
889 else
890 reductionName = "other_reduction";
891
892 return (llvm::Twine(reductionName) +
893 (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
894 llvm::Twine(ty.getIntOrFloatBitWidth()))
895 .str();
896 }
897
genOMP(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OpenMPLoopConstruct & loopConstruct)898 static void genOMP(Fortran::lower::AbstractConverter &converter,
899 Fortran::lower::pft::Evaluation &eval,
900 const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
901
902 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
903 mlir::Location currentLocation = converter.getCurrentLocation();
904 llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, linearVars,
905 linearStepVars, reductionVars;
906 mlir::Value scheduleChunkClauseOperand, ifClauseOperand;
907 mlir::Attribute scheduleClauseOperand, noWaitClauseOperand,
908 orderedClauseOperand, orderClauseOperand;
909 SmallVector<Attribute> reductionDeclSymbols;
910 Fortran::lower::StatementContext stmtCtx;
911 const auto &loopOpClauseList = std::get<Fortran::parser::OmpClauseList>(
912 std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t).t);
913
914 const auto ompDirective =
915 std::get<Fortran::parser::OmpLoopDirective>(
916 std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t).t)
917 .v;
918 if (llvm::omp::OMPD_parallel_do == ompDirective) {
919 createCombinedParallelOp<Fortran::parser::OmpBeginLoopDirective>(
920 converter, eval,
921 std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t));
922 } else if (llvm::omp::OMPD_do != ompDirective &&
923 llvm::omp::OMPD_simd != ompDirective) {
924 TODO(converter.getCurrentLocation(), "Construct enclosing do loop");
925 }
926
927 // Collect the loops to collapse.
928 auto *doConstructEval = &eval.getFirstNestedEvaluation();
929
930 std::int64_t collapseValue =
931 Fortran::lower::getCollapseValue(loopOpClauseList);
932 std::size_t loopVarTypeSize = 0;
933 SmallVector<const Fortran::semantics::Symbol *> iv;
934 do {
935 auto *doLoop = &doConstructEval->getFirstNestedEvaluation();
936 auto *doStmt = doLoop->getIf<Fortran::parser::NonLabelDoStmt>();
937 assert(doStmt && "Expected do loop to be in the nested evaluation");
938 const auto &loopControl =
939 std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t);
940 const Fortran::parser::LoopControl::Bounds *bounds =
941 std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
942 assert(bounds && "Expected bounds for worksharing do loop");
943 Fortran::lower::StatementContext stmtCtx;
944 lowerBound.push_back(fir::getBase(converter.genExprValue(
945 *Fortran::semantics::GetExpr(bounds->lower), stmtCtx)));
946 upperBound.push_back(fir::getBase(converter.genExprValue(
947 *Fortran::semantics::GetExpr(bounds->upper), stmtCtx)));
948 if (bounds->step) {
949 step.push_back(fir::getBase(converter.genExprValue(
950 *Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
951 } else { // If `step` is not present, assume it as `1`.
952 step.push_back(firOpBuilder.createIntegerConstant(
953 currentLocation, firOpBuilder.getIntegerType(32), 1));
954 }
955 iv.push_back(bounds->name.thing.symbol);
956 loopVarTypeSize = std::max(loopVarTypeSize,
957 bounds->name.thing.symbol->GetUltimate().size());
958
959 collapseValue--;
960 doConstructEval =
961 &*std::next(doConstructEval->getNestedEvaluations().begin());
962 } while (collapseValue > 0);
963
964 for (const auto &clause : loopOpClauseList.v) {
965 if (const auto &scheduleClause =
966 std::get_if<Fortran::parser::OmpClause::Schedule>(&clause.u)) {
967 if (const auto &chunkExpr =
968 std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
969 scheduleClause->v.t)) {
970 if (const auto *expr = Fortran::semantics::GetExpr(*chunkExpr)) {
971 scheduleChunkClauseOperand =
972 fir::getBase(converter.genExprValue(*expr, stmtCtx));
973 }
974 }
975 } else if (const auto &ifClause =
976 std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
977 ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause);
978 } else if (const auto &reductionClause =
979 std::get_if<Fortran::parser::OmpClause::Reduction>(
980 &clause.u)) {
981 omp::ReductionDeclareOp decl;
982 const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>(
983 reductionClause->v.t)};
984 const auto &objectList{
985 std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)};
986 if (const auto &redDefinedOp =
987 std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
988 const auto &intrinsicOp{
989 std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
990 redDefinedOp->u)};
991 if (intrinsicOp !=
992 Fortran::parser::DefinedOperator::IntrinsicOperator::Add)
993 TODO(currentLocation,
994 "Reduction of some intrinsic operators is not supported");
995 for (const auto &ompObject : objectList.v) {
996 if (const auto *name{
997 Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
998 if (const auto *symbol{name->symbol}) {
999 mlir::Value symVal = converter.getSymbolAddress(*symbol);
1000 mlir::Type redType =
1001 symVal.getType().cast<fir::ReferenceType>().getEleTy();
1002 reductionVars.push_back(symVal);
1003 if (redType.isIntOrIndex()) {
1004 decl = createReductionDecl(
1005 firOpBuilder, getReductionName(intrinsicOp, redType),
1006 redType, currentLocation);
1007 } else {
1008 TODO(currentLocation,
1009 "Reduction of some types is not supported");
1010 }
1011 reductionDeclSymbols.push_back(SymbolRefAttr::get(
1012 firOpBuilder.getContext(), decl.sym_name()));
1013 }
1014 }
1015 }
1016 } else {
1017 TODO(currentLocation,
1018 "Reduction of intrinsic procedures is not supported");
1019 }
1020 }
1021 }
1022
1023 // The types of lower bound, upper bound, and step are converted into the
1024 // type of the loop variable if necessary.
1025 mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
1026 for (unsigned it = 0; it < (unsigned)lowerBound.size(); it++) {
1027 lowerBound[it] = firOpBuilder.createConvert(currentLocation, loopVarType,
1028 lowerBound[it]);
1029 upperBound[it] = firOpBuilder.createConvert(currentLocation, loopVarType,
1030 upperBound[it]);
1031 step[it] =
1032 firOpBuilder.createConvert(currentLocation, loopVarType, step[it]);
1033 }
1034
1035 // 2.9.3.1 SIMD construct
1036 // TODO: Support all the clauses
1037 if (llvm::omp::OMPD_simd == ompDirective) {
1038 TypeRange resultType;
1039 auto SimdLoopOp = firOpBuilder.create<mlir::omp::SimdLoopOp>(
1040 currentLocation, resultType, lowerBound, upperBound, step,
1041 ifClauseOperand, /*inclusive=*/firOpBuilder.getUnitAttr());
1042 createBodyOfOp<omp::SimdLoopOp>(SimdLoopOp, converter, currentLocation,
1043 eval, &loopOpClauseList, iv);
1044 return;
1045 }
1046
1047 // FIXME: Add support for following clauses:
1048 // 1. linear
1049 // 2. order
1050 auto wsLoopOp = firOpBuilder.create<mlir::omp::WsLoopOp>(
1051 currentLocation, lowerBound, upperBound, step, linearVars, linearStepVars,
1052 reductionVars,
1053 reductionDeclSymbols.empty()
1054 ? nullptr
1055 : mlir::ArrayAttr::get(firOpBuilder.getContext(),
1056 reductionDeclSymbols),
1057 scheduleClauseOperand.dyn_cast_or_null<omp::ClauseScheduleKindAttr>(),
1058 scheduleChunkClauseOperand, /*schedule_modifiers=*/nullptr,
1059 /*simd_modifier=*/nullptr,
1060 noWaitClauseOperand.dyn_cast_or_null<UnitAttr>(),
1061 orderedClauseOperand.dyn_cast_or_null<IntegerAttr>(),
1062 orderClauseOperand.dyn_cast_or_null<omp::ClauseOrderKindAttr>(),
1063 /*inclusive=*/firOpBuilder.getUnitAttr());
1064
1065 // Handle attribute based clauses.
1066 for (const Fortran::parser::OmpClause &clause : loopOpClauseList.v) {
1067 if (const auto &orderedClause =
1068 std::get_if<Fortran::parser::OmpClause::Ordered>(&clause.u)) {
1069 if (orderedClause->v.has_value()) {
1070 const auto *expr = Fortran::semantics::GetExpr(orderedClause->v);
1071 const std::optional<std::int64_t> orderedClauseValue =
1072 Fortran::evaluate::ToInt64(*expr);
1073 wsLoopOp.ordered_valAttr(
1074 firOpBuilder.getI64IntegerAttr(*orderedClauseValue));
1075 } else {
1076 wsLoopOp.ordered_valAttr(firOpBuilder.getI64IntegerAttr(0));
1077 }
1078 } else if (const auto &scheduleClause =
1079 std::get_if<Fortran::parser::OmpClause::Schedule>(
1080 &clause.u)) {
1081 mlir::MLIRContext *context = firOpBuilder.getContext();
1082 const auto &scheduleType = scheduleClause->v;
1083 const auto &scheduleKind =
1084 std::get<Fortran::parser::OmpScheduleClause::ScheduleType>(
1085 scheduleType.t);
1086 switch (scheduleKind) {
1087 case Fortran::parser::OmpScheduleClause::ScheduleType::Static:
1088 wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get(
1089 context, omp::ClauseScheduleKind::Static));
1090 break;
1091 case Fortran::parser::OmpScheduleClause::ScheduleType::Dynamic:
1092 wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get(
1093 context, omp::ClauseScheduleKind::Dynamic));
1094 break;
1095 case Fortran::parser::OmpScheduleClause::ScheduleType::Guided:
1096 wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get(
1097 context, omp::ClauseScheduleKind::Guided));
1098 break;
1099 case Fortran::parser::OmpScheduleClause::ScheduleType::Auto:
1100 wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get(
1101 context, omp::ClauseScheduleKind::Auto));
1102 break;
1103 case Fortran::parser::OmpScheduleClause::ScheduleType::Runtime:
1104 wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get(
1105 context, omp::ClauseScheduleKind::Runtime));
1106 break;
1107 }
1108 mlir::omp::ScheduleModifier scheduleModifier =
1109 getScheduleModifier(scheduleClause->v);
1110 if (scheduleModifier != mlir::omp::ScheduleModifier::none)
1111 wsLoopOp.schedule_modifierAttr(
1112 omp::ScheduleModifierAttr::get(context, scheduleModifier));
1113 if (getSIMDModifier(scheduleClause->v) !=
1114 mlir::omp::ScheduleModifier::none)
1115 wsLoopOp.simd_modifierAttr(firOpBuilder.getUnitAttr());
1116 }
1117 }
1118 // In FORTRAN `nowait` clause occur at the end of `omp do` directive.
1119 // i.e
1120 // !$omp do
1121 // <...>
1122 // !$omp end do nowait
1123 if (const auto &endClauseList =
1124 std::get<std::optional<Fortran::parser::OmpEndLoopDirective>>(
1125 loopConstruct.t)) {
1126 const auto &clauseList =
1127 std::get<Fortran::parser::OmpClauseList>((*endClauseList).t);
1128 for (const Fortran::parser::OmpClause &clause : clauseList.v)
1129 if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u))
1130 wsLoopOp.nowaitAttr(firOpBuilder.getUnitAttr());
1131 }
1132
1133 createBodyOfOp<omp::WsLoopOp>(wsLoopOp, converter, currentLocation, eval,
1134 &loopOpClauseList, iv);
1135 }
1136
1137 static void
genOMP(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OpenMPCriticalConstruct & criticalConstruct)1138 genOMP(Fortran::lower::AbstractConverter &converter,
1139 Fortran::lower::pft::Evaluation &eval,
1140 const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) {
1141 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1142 mlir::Location currentLocation = converter.getCurrentLocation();
1143 std::string name;
1144 const Fortran::parser::OmpCriticalDirective &cd =
1145 std::get<Fortran::parser::OmpCriticalDirective>(criticalConstruct.t);
1146 if (std::get<std::optional<Fortran::parser::Name>>(cd.t).has_value()) {
1147 name =
1148 std::get<std::optional<Fortran::parser::Name>>(cd.t).value().ToString();
1149 }
1150
1151 uint64_t hint = 0;
1152 const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t);
1153 for (const Fortran::parser::OmpClause &clause : clauseList.v)
1154 if (auto hintClause =
1155 std::get_if<Fortran::parser::OmpClause::Hint>(&clause.u)) {
1156 const auto *expr = Fortran::semantics::GetExpr(hintClause->v);
1157 hint = *Fortran::evaluate::ToInt64(*expr);
1158 break;
1159 }
1160
1161 mlir::omp::CriticalOp criticalOp = [&]() {
1162 if (name.empty()) {
1163 return firOpBuilder.create<mlir::omp::CriticalOp>(currentLocation,
1164 FlatSymbolRefAttr());
1165 } else {
1166 mlir::ModuleOp module = firOpBuilder.getModule();
1167 mlir::OpBuilder modBuilder(module.getBodyRegion());
1168 auto global = module.lookupSymbol<mlir::omp::CriticalDeclareOp>(name);
1169 if (!global)
1170 global = modBuilder.create<mlir::omp::CriticalDeclareOp>(
1171 currentLocation, name, hint);
1172 return firOpBuilder.create<mlir::omp::CriticalOp>(
1173 currentLocation, mlir::FlatSymbolRefAttr::get(
1174 firOpBuilder.getContext(), global.sym_name()));
1175 }
1176 }();
1177 createBodyOfOp<omp::CriticalOp>(criticalOp, converter, currentLocation, eval);
1178 }
1179
1180 static void
genOMP(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OpenMPSectionConstruct & sectionConstruct)1181 genOMP(Fortran::lower::AbstractConverter &converter,
1182 Fortran::lower::pft::Evaluation &eval,
1183 const Fortran::parser::OpenMPSectionConstruct §ionConstruct) {
1184
1185 auto &firOpBuilder = converter.getFirOpBuilder();
1186 auto currentLocation = converter.getCurrentLocation();
1187 mlir::omp::SectionOp sectionOp =
1188 firOpBuilder.create<mlir::omp::SectionOp>(currentLocation);
1189 createBodyOfOp<omp::SectionOp>(sectionOp, converter, currentLocation, eval);
1190 }
1191
1192 // TODO: Add support for reduction
1193 static void
genOMP(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OpenMPSectionsConstruct & sectionsConstruct)1194 genOMP(Fortran::lower::AbstractConverter &converter,
1195 Fortran::lower::pft::Evaluation &eval,
1196 const Fortran::parser::OpenMPSectionsConstruct §ionsConstruct) {
1197 auto &firOpBuilder = converter.getFirOpBuilder();
1198 auto currentLocation = converter.getCurrentLocation();
1199 SmallVector<Value> reductionVars, allocateOperands, allocatorOperands;
1200 mlir::UnitAttr noWaitClauseOperand;
1201 const auto §ionsClauseList = std::get<Fortran::parser::OmpClauseList>(
1202 std::get<Fortran::parser::OmpBeginSectionsDirective>(sectionsConstruct.t)
1203 .t);
1204 for (const Fortran::parser::OmpClause &clause : sectionsClauseList.v) {
1205
1206 // Reduction Clause
1207 if (std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) {
1208 TODO(currentLocation, "OMPC_Reduction");
1209
1210 // Allocate clause
1211 } else if (const auto &allocateClause =
1212 std::get_if<Fortran::parser::OmpClause::Allocate>(
1213 &clause.u)) {
1214 genAllocateClause(converter, allocateClause->v, allocatorOperands,
1215 allocateOperands);
1216 }
1217 }
1218 const auto &endSectionsClauseList =
1219 std::get<Fortran::parser::OmpEndSectionsDirective>(sectionsConstruct.t);
1220 const auto &clauseList =
1221 std::get<Fortran::parser::OmpClauseList>(endSectionsClauseList.t);
1222 for (const auto &clause : clauseList.v) {
1223 // Nowait clause
1224 if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u)) {
1225 noWaitClauseOperand = firOpBuilder.getUnitAttr();
1226 }
1227 }
1228
1229 llvm::omp::Directive dir =
1230 std::get<Fortran::parser::OmpSectionsDirective>(
1231 std::get<Fortran::parser::OmpBeginSectionsDirective>(
1232 sectionsConstruct.t)
1233 .t)
1234 .v;
1235
1236 // Parallel Sections Construct
1237 if (dir == llvm::omp::Directive::OMPD_parallel_sections) {
1238 createCombinedParallelOp<Fortran::parser::OmpBeginSectionsDirective>(
1239 converter, eval,
1240 std::get<Fortran::parser::OmpBeginSectionsDirective>(
1241 sectionsConstruct.t));
1242 auto sectionsOp = firOpBuilder.create<mlir::omp::SectionsOp>(
1243 currentLocation, /*reduction_vars*/ ValueRange(),
1244 /*reductions=*/nullptr, allocateOperands, allocatorOperands,
1245 /*nowait=*/nullptr);
1246 createBodyOfOp(sectionsOp, converter, currentLocation, eval);
1247
1248 // Sections Construct
1249 } else if (dir == llvm::omp::Directive::OMPD_sections) {
1250 auto sectionsOp = firOpBuilder.create<mlir::omp::SectionsOp>(
1251 currentLocation, reductionVars, /*reductions = */ nullptr,
1252 allocateOperands, allocatorOperands, noWaitClauseOperand);
1253 createBodyOfOp<omp::SectionsOp>(sectionsOp, converter, currentLocation,
1254 eval);
1255 }
1256 }
1257
genOmpAtomicHintAndMemoryOrderClauses(Fortran::lower::AbstractConverter & converter,const Fortran::parser::OmpAtomicClauseList & clauseList,mlir::IntegerAttr & hint,mlir::omp::ClauseMemoryOrderKindAttr & memory_order)1258 static void genOmpAtomicHintAndMemoryOrderClauses(
1259 Fortran::lower::AbstractConverter &converter,
1260 const Fortran::parser::OmpAtomicClauseList &clauseList,
1261 mlir::IntegerAttr &hint,
1262 mlir::omp::ClauseMemoryOrderKindAttr &memory_order) {
1263 auto &firOpBuilder = converter.getFirOpBuilder();
1264 for (const auto &clause : clauseList.v) {
1265 if (auto ompClause = std::get_if<Fortran::parser::OmpClause>(&clause.u)) {
1266 if (auto hintClause =
1267 std::get_if<Fortran::parser::OmpClause::Hint>(&ompClause->u)) {
1268 const auto *expr = Fortran::semantics::GetExpr(hintClause->v);
1269 uint64_t hintExprValue = *Fortran::evaluate::ToInt64(*expr);
1270 hint = firOpBuilder.getI64IntegerAttr(hintExprValue);
1271 }
1272 } else if (auto ompMemoryOrderClause =
1273 std::get_if<Fortran::parser::OmpMemoryOrderClause>(
1274 &clause.u)) {
1275 if (std::get_if<Fortran::parser::OmpClause::Acquire>(
1276 &ompMemoryOrderClause->v.u)) {
1277 memory_order = mlir::omp::ClauseMemoryOrderKindAttr::get(
1278 firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Acquire);
1279 } else if (std::get_if<Fortran::parser::OmpClause::Relaxed>(
1280 &ompMemoryOrderClause->v.u)) {
1281 memory_order = mlir::omp::ClauseMemoryOrderKindAttr::get(
1282 firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Relaxed);
1283 } else if (std::get_if<Fortran::parser::OmpClause::SeqCst>(
1284 &ompMemoryOrderClause->v.u)) {
1285 memory_order = mlir::omp::ClauseMemoryOrderKindAttr::get(
1286 firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Seq_cst);
1287 } else if (std::get_if<Fortran::parser::OmpClause::Release>(
1288 &ompMemoryOrderClause->v.u)) {
1289 memory_order = mlir::omp::ClauseMemoryOrderKindAttr::get(
1290 firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Release);
1291 }
1292 }
1293 }
1294 }
1295
genOmpAtomicUpdateStatement(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::Variable & assignmentStmtVariable,const Fortran::parser::Expr & assignmentStmtExpr,const Fortran::parser::OmpAtomicClauseList * leftHandClauseList,const Fortran::parser::OmpAtomicClauseList * rightHandClauseList)1296 static void genOmpAtomicUpdateStatement(
1297 Fortran::lower::AbstractConverter &converter,
1298 Fortran::lower::pft::Evaluation &eval,
1299 const Fortran::parser::Variable &assignmentStmtVariable,
1300 const Fortran::parser::Expr &assignmentStmtExpr,
1301 const Fortran::parser::OmpAtomicClauseList *leftHandClauseList,
1302 const Fortran::parser::OmpAtomicClauseList *rightHandClauseList) {
1303 // Generate `omp.atomic.update` operation for atomic assignment statements
1304 auto &firOpBuilder = converter.getFirOpBuilder();
1305 auto currentLocation = converter.getCurrentLocation();
1306 Fortran::lower::StatementContext stmtCtx;
1307
1308 mlir::Value address = fir::getBase(converter.genExprAddr(
1309 *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
1310 // If no hint clause is specified, the effect is as if
1311 // hint(omp_sync_hint_none) had been specified.
1312 mlir::IntegerAttr hint = nullptr;
1313 mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr;
1314 if (leftHandClauseList)
1315 genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList, hint,
1316 memory_order);
1317 if (rightHandClauseList)
1318 genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, hint,
1319 memory_order);
1320 auto atomicUpdateOp = firOpBuilder.create<mlir::omp::AtomicUpdateOp>(
1321 currentLocation, address, hint, memory_order);
1322
1323 //// Generate body of Atomic Update operation
1324 // If an argument for the region is provided then create the block with that
1325 // argument. Also update the symbol's address with the argument mlir value.
1326 mlir::Type varType =
1327 fir::getBase(
1328 converter.genExprValue(
1329 *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx))
1330 .getType();
1331 SmallVector<Type> varTys = {varType};
1332 SmallVector<Location> locs = {currentLocation};
1333 firOpBuilder.createBlock(&atomicUpdateOp.getRegion(), {}, varTys, locs);
1334 mlir::Value val =
1335 fir::getBase(atomicUpdateOp.getRegion().front().getArgument(0));
1336 auto varDesignator =
1337 std::get_if<Fortran::common::Indirection<Fortran::parser::Designator>>(
1338 &assignmentStmtVariable.u);
1339 assert(varDesignator && "Variable designator for atomic update assignment "
1340 "statement does not exist");
1341 const auto *name = getDesignatorNameIfDataRef(varDesignator->value());
1342 assert(name && name->symbol &&
1343 "No symbol attached to atomic update variable");
1344 converter.bindSymbol(*name->symbol, val);
1345 // Set the insert for the terminator operation to go at the end of the
1346 // block.
1347 mlir::Block &block = atomicUpdateOp.getRegion().back();
1348 firOpBuilder.setInsertionPointToEnd(&block);
1349
1350 mlir::Value result = fir::getBase(converter.genExprValue(
1351 *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
1352 // Insert the terminator: YieldOp.
1353 firOpBuilder.create<mlir::omp::YieldOp>(currentLocation, result);
1354 // Reset the insert point to before the terminator.
1355 firOpBuilder.setInsertionPointToStart(&block);
1356 }
1357
1358 static void
genOmpAtomicWrite(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OmpAtomicWrite & atomicWrite)1359 genOmpAtomicWrite(Fortran::lower::AbstractConverter &converter,
1360 Fortran::lower::pft::Evaluation &eval,
1361 const Fortran::parser::OmpAtomicWrite &atomicWrite) {
1362 auto &firOpBuilder = converter.getFirOpBuilder();
1363 auto currentLocation = converter.getCurrentLocation();
1364 // Get the value and address of atomic write operands.
1365 const Fortran::parser::OmpAtomicClauseList &rightHandClauseList =
1366 std::get<2>(atomicWrite.t);
1367 const Fortran::parser::OmpAtomicClauseList &leftHandClauseList =
1368 std::get<0>(atomicWrite.t);
1369 const auto &assignmentStmtExpr =
1370 std::get<Fortran::parser::Expr>(std::get<3>(atomicWrite.t).statement.t);
1371 const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
1372 std::get<3>(atomicWrite.t).statement.t);
1373 Fortran::lower::StatementContext stmtCtx;
1374 mlir::Value value = fir::getBase(converter.genExprValue(
1375 *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
1376 mlir::Value address = fir::getBase(converter.genExprAddr(
1377 *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
1378 // If no hint clause is specified, the effect is as if
1379 // hint(omp_sync_hint_none) had been specified.
1380 mlir::IntegerAttr hint = nullptr;
1381 mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr;
1382 genOmpAtomicHintAndMemoryOrderClauses(converter, leftHandClauseList, hint,
1383 memory_order);
1384 genOmpAtomicHintAndMemoryOrderClauses(converter, rightHandClauseList, hint,
1385 memory_order);
1386 firOpBuilder.create<mlir::omp::AtomicWriteOp>(currentLocation, address, value,
1387 hint, memory_order);
1388 }
1389
genOmpAtomicRead(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OmpAtomicRead & atomicRead)1390 static void genOmpAtomicRead(Fortran::lower::AbstractConverter &converter,
1391 Fortran::lower::pft::Evaluation &eval,
1392 const Fortran::parser::OmpAtomicRead &atomicRead) {
1393 auto &firOpBuilder = converter.getFirOpBuilder();
1394 auto currentLocation = converter.getCurrentLocation();
1395 // Get the address of atomic read operands.
1396 const Fortran::parser::OmpAtomicClauseList &rightHandClauseList =
1397 std::get<2>(atomicRead.t);
1398 const Fortran::parser::OmpAtomicClauseList &leftHandClauseList =
1399 std::get<0>(atomicRead.t);
1400 const auto &assignmentStmtExpr =
1401 std::get<Fortran::parser::Expr>(std::get<3>(atomicRead.t).statement.t);
1402 const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
1403 std::get<3>(atomicRead.t).statement.t);
1404 Fortran::lower::StatementContext stmtCtx;
1405 mlir::Value from_address = fir::getBase(converter.genExprAddr(
1406 *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
1407 mlir::Value to_address = fir::getBase(converter.genExprAddr(
1408 *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx));
1409 // If no hint clause is specified, the effect is as if
1410 // hint(omp_sync_hint_none) had been specified.
1411 mlir::IntegerAttr hint = nullptr;
1412 mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr;
1413 genOmpAtomicHintAndMemoryOrderClauses(converter, leftHandClauseList, hint,
1414 memory_order);
1415 genOmpAtomicHintAndMemoryOrderClauses(converter, rightHandClauseList, hint,
1416 memory_order);
1417 firOpBuilder.create<mlir::omp::AtomicReadOp>(currentLocation, from_address,
1418 to_address, hint, memory_order);
1419 }
1420
1421 static void
genOmpAtomicUpdate(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OmpAtomicUpdate & atomicUpdate)1422 genOmpAtomicUpdate(Fortran::lower::AbstractConverter &converter,
1423 Fortran::lower::pft::Evaluation &eval,
1424 const Fortran::parser::OmpAtomicUpdate &atomicUpdate) {
1425 const Fortran::parser::OmpAtomicClauseList &rightHandClauseList =
1426 std::get<2>(atomicUpdate.t);
1427 const Fortran::parser::OmpAtomicClauseList &leftHandClauseList =
1428 std::get<0>(atomicUpdate.t);
1429 const auto &assignmentStmtExpr =
1430 std::get<Fortran::parser::Expr>(std::get<3>(atomicUpdate.t).statement.t);
1431 const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
1432 std::get<3>(atomicUpdate.t).statement.t);
1433
1434 genOmpAtomicUpdateStatement(converter, eval, assignmentStmtVariable,
1435 assignmentStmtExpr, &leftHandClauseList,
1436 &rightHandClauseList);
1437 }
1438
genOmpAtomic(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OmpAtomic & atomicConstruct)1439 static void genOmpAtomic(Fortran::lower::AbstractConverter &converter,
1440 Fortran::lower::pft::Evaluation &eval,
1441 const Fortran::parser::OmpAtomic &atomicConstruct) {
1442 const Fortran::parser::OmpAtomicClauseList &atomicClauseList =
1443 std::get<Fortran::parser::OmpAtomicClauseList>(atomicConstruct.t);
1444 const auto &assignmentStmtExpr = std::get<Fortran::parser::Expr>(
1445 std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
1446 atomicConstruct.t)
1447 .statement.t);
1448 const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>(
1449 std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>(
1450 atomicConstruct.t)
1451 .statement.t);
1452 // If atomic-clause is not present on the construct, the behaviour is as if
1453 // the update clause is specified
1454 genOmpAtomicUpdateStatement(converter, eval, assignmentStmtVariable,
1455 assignmentStmtExpr, &atomicClauseList, nullptr);
1456 }
1457
1458 static void
genOMP(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OpenMPAtomicConstruct & atomicConstruct)1459 genOMP(Fortran::lower::AbstractConverter &converter,
1460 Fortran::lower::pft::Evaluation &eval,
1461 const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) {
1462 std::visit(Fortran::common::visitors{
1463 [&](const Fortran::parser::OmpAtomicRead &atomicRead) {
1464 genOmpAtomicRead(converter, eval, atomicRead);
1465 },
1466 [&](const Fortran::parser::OmpAtomicWrite &atomicWrite) {
1467 genOmpAtomicWrite(converter, eval, atomicWrite);
1468 },
1469 [&](const Fortran::parser::OmpAtomic &atomicConstruct) {
1470 genOmpAtomic(converter, eval, atomicConstruct);
1471 },
1472 [&](const Fortran::parser::OmpAtomicUpdate &atomicUpdate) {
1473 genOmpAtomicUpdate(converter, eval, atomicUpdate);
1474 },
1475 [&](const auto &) {
1476 TODO(converter.getCurrentLocation(), "Atomic capture");
1477 },
1478 },
1479 atomicConstruct.u);
1480 }
1481
genOpenMPConstruct(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OpenMPConstruct & ompConstruct)1482 void Fortran::lower::genOpenMPConstruct(
1483 Fortran::lower::AbstractConverter &converter,
1484 Fortran::lower::pft::Evaluation &eval,
1485 const Fortran::parser::OpenMPConstruct &ompConstruct) {
1486
1487 std::visit(
1488 common::visitors{
1489 [&](const Fortran::parser::OpenMPStandaloneConstruct
1490 &standaloneConstruct) {
1491 genOMP(converter, eval, standaloneConstruct);
1492 },
1493 [&](const Fortran::parser::OpenMPSectionsConstruct
1494 §ionsConstruct) {
1495 genOMP(converter, eval, sectionsConstruct);
1496 },
1497 [&](const Fortran::parser::OpenMPSectionConstruct §ionConstruct) {
1498 genOMP(converter, eval, sectionConstruct);
1499 },
1500 [&](const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
1501 genOMP(converter, eval, loopConstruct);
1502 },
1503 [&](const Fortran::parser::OpenMPDeclarativeAllocate
1504 &execAllocConstruct) {
1505 TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate");
1506 },
1507 [&](const Fortran::parser::OpenMPExecutableAllocate
1508 &execAllocConstruct) {
1509 TODO(converter.getCurrentLocation(), "OpenMPExecutableAllocate");
1510 },
1511 [&](const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
1512 genOMP(converter, eval, blockConstruct);
1513 },
1514 [&](const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) {
1515 genOMP(converter, eval, atomicConstruct);
1516 },
1517 [&](const Fortran::parser::OpenMPCriticalConstruct
1518 &criticalConstruct) {
1519 genOMP(converter, eval, criticalConstruct);
1520 },
1521 },
1522 ompConstruct.u);
1523 }
1524
genThreadprivateOp(Fortran::lower::AbstractConverter & converter,const Fortran::lower::pft::Variable & var)1525 void Fortran::lower::genThreadprivateOp(
1526 Fortran::lower::AbstractConverter &converter,
1527 const Fortran::lower::pft::Variable &var) {
1528 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1529 mlir::Location currentLocation = converter.getCurrentLocation();
1530
1531 const Fortran::semantics::Symbol &sym = var.getSymbol();
1532 mlir::Value symThreadprivateValue;
1533 if (const Fortran::semantics::Symbol *common =
1534 Fortran::semantics::FindCommonBlockContaining(sym.GetUltimate())) {
1535 mlir::Value commonValue = converter.getSymbolAddress(*common);
1536 if (mlir::isa<mlir::omp::ThreadprivateOp>(commonValue.getDefiningOp())) {
1537 // Generate ThreadprivateOp for a common block instead of its members and
1538 // only do it once for a common block.
1539 return;
1540 }
1541 // Generate ThreadprivateOp and rebind the common block.
1542 mlir::Value commonThreadprivateValue =
1543 firOpBuilder.create<mlir::omp::ThreadprivateOp>(
1544 currentLocation, commonValue.getType(), commonValue);
1545 converter.bindSymbol(*common, commonThreadprivateValue);
1546 // Generate the threadprivate value for the common block member.
1547 symThreadprivateValue =
1548 genCommonBlockMember(converter, sym, commonThreadprivateValue);
1549 } else {
1550 mlir::Value symValue = converter.getSymbolAddress(sym);
1551 symThreadprivateValue = firOpBuilder.create<mlir::omp::ThreadprivateOp>(
1552 currentLocation, symValue.getType(), symValue);
1553 }
1554
1555 fir::ExtendedValue sexv = converter.getSymbolExtendedValue(sym);
1556 fir::ExtendedValue symThreadprivateExv =
1557 getExtendedValue(sexv, symThreadprivateValue);
1558 converter.bindSymbol(sym, symThreadprivateExv);
1559 }
1560
genOpenMPDeclarativeConstruct(Fortran::lower::AbstractConverter & converter,Fortran::lower::pft::Evaluation & eval,const Fortran::parser::OpenMPDeclarativeConstruct & ompDeclConstruct)1561 void Fortran::lower::genOpenMPDeclarativeConstruct(
1562 Fortran::lower::AbstractConverter &converter,
1563 Fortran::lower::pft::Evaluation &eval,
1564 const Fortran::parser::OpenMPDeclarativeConstruct &ompDeclConstruct) {
1565
1566 std::visit(
1567 common::visitors{
1568 [&](const Fortran::parser::OpenMPDeclarativeAllocate
1569 &declarativeAllocate) {
1570 TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate");
1571 },
1572 [&](const Fortran::parser::OpenMPDeclareReductionConstruct
1573 &declareReductionConstruct) {
1574 TODO(converter.getCurrentLocation(),
1575 "OpenMPDeclareReductionConstruct");
1576 },
1577 [&](const Fortran::parser::OpenMPDeclareSimdConstruct
1578 &declareSimdConstruct) {
1579 TODO(converter.getCurrentLocation(), "OpenMPDeclareSimdConstruct");
1580 },
1581 [&](const Fortran::parser::OpenMPDeclareTargetConstruct
1582 &declareTargetConstruct) {
1583 TODO(converter.getCurrentLocation(),
1584 "OpenMPDeclareTargetConstruct");
1585 },
1586 [&](const Fortran::parser::OpenMPThreadprivate &threadprivate) {
1587 // The directive is lowered when instantiating the variable to
1588 // support the case of threadprivate variable declared in module.
1589 },
1590 },
1591 ompDeclConstruct.u);
1592 }
1593
1594 // Generate an OpenMP reduction operation. This implementation finds the chain :
1595 // load reduction var -> reduction_operation -> store reduction var and replaces
1596 // it with the reduction operation.
1597 // TODO: Currently assumes it is an integer addition reduction. Generalize this
1598 // for various reduction operation types.
1599 // TODO: Generate the reduction operation during lowering instead of creating
1600 // and removing operations since this is not a robust approach. Also, removing
1601 // ops in the builder (instead of a rewriter) is probably not the best approach.
genOpenMPReduction(Fortran::lower::AbstractConverter & converter,const Fortran::parser::OmpClauseList & clauseList)1602 void Fortran::lower::genOpenMPReduction(
1603 Fortran::lower::AbstractConverter &converter,
1604 const Fortran::parser::OmpClauseList &clauseList) {
1605 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1606
1607 for (const auto &clause : clauseList.v) {
1608 if (const auto &reductionClause =
1609 std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) {
1610 const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>(
1611 reductionClause->v.t)};
1612 const auto &objectList{
1613 std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)};
1614 if (auto reductionOp =
1615 std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
1616 const auto &intrinsicOp{
1617 std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
1618 reductionOp->u)};
1619 if (intrinsicOp !=
1620 Fortran::parser::DefinedOperator::IntrinsicOperator::Add)
1621 continue;
1622 for (const auto &ompObject : objectList.v) {
1623 if (const auto *name{
1624 Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
1625 if (const auto *symbol{name->symbol}) {
1626 mlir::Value symVal = converter.getSymbolAddress(*symbol);
1627 mlir::Type redType =
1628 symVal.getType().cast<fir::ReferenceType>().getEleTy();
1629 if (!redType.isIntOrIndex())
1630 continue;
1631 for (mlir::OpOperand &use1 : symVal.getUses()) {
1632 if (auto load = mlir::dyn_cast<fir::LoadOp>(use1.getOwner())) {
1633 mlir::Value loadVal = load.getRes();
1634 for (mlir::OpOperand &use2 : loadVal.getUses()) {
1635 if (auto add = mlir::dyn_cast<mlir::arith::AddIOp>(
1636 use2.getOwner())) {
1637 mlir::Value addRes = add.getResult();
1638 for (mlir::OpOperand &use3 : addRes.getUses()) {
1639 if (auto store =
1640 mlir::dyn_cast<fir::StoreOp>(use3.getOwner())) {
1641 if (store.getMemref() == symVal) {
1642 // Chain found! Now replace load->reduction->store
1643 // with the OpenMP reduction operation.
1644 mlir::OpBuilder::InsertPoint insertPtDel =
1645 firOpBuilder.saveInsertionPoint();
1646 firOpBuilder.setInsertionPoint(add);
1647 if (add.getLhs() == loadVal) {
1648 firOpBuilder.create<mlir::omp::ReductionOp>(
1649 add.getLoc(), add.getRhs(), symVal);
1650 } else {
1651 firOpBuilder.create<mlir::omp::ReductionOp>(
1652 add.getLoc(), add.getLhs(), symVal);
1653 }
1654 store.erase();
1655 add.erase();
1656 load.erase();
1657 firOpBuilder.restoreInsertionPoint(insertPtDel);
1658 }
1659 }
1660 }
1661 }
1662 }
1663 }
1664 }
1665 }
1666 }
1667 }
1668 }
1669 }
1670 }
1671 }
1672