1 //===-- OpenMP.cpp -- Open MP directive lowering --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "flang/Lower/OpenMP.h"
#include "flang/Common/idioms.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Lower/Todo.h"
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/tools.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Support/ErrorHandling.h"
25 
26 static const Fortran::parser::Name *
27 getDesignatorNameIfDataRef(const Fortran::parser::Designator &designator) {
28   const auto *dataRef = std::get_if<Fortran::parser::DataRef>(&designator.u);
29   return dataRef ? std::get_if<Fortran::parser::Name>(&dataRef->u) : nullptr;
30 }
31 
32 static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
33                           Fortran::lower::AbstractConverter &converter,
34                           SmallVectorImpl<Value> &operands) {
35   for (const auto &ompObject : objectList.v) {
36     std::visit(
37         Fortran::common::visitors{
38             [&](const Fortran::parser::Designator &designator) {
39               if (const auto *name = getDesignatorNameIfDataRef(designator)) {
40                 const auto variable = converter.getSymbolAddress(*name->symbol);
41                 operands.push_back(variable);
42               }
43             },
44             [&](const Fortran::parser::Name &name) {
45               const auto variable = converter.getSymbolAddress(*name.symbol);
46               operands.push_back(variable);
47             }},
48         ompObject.u);
49   }
50 }
51 
52 template <typename Op>
53 static void createBodyOfOp(Op &op, fir::FirOpBuilder &firOpBuilder,
54                            mlir::Location &loc) {
55   firOpBuilder.createBlock(&op.getRegion());
56   auto &block = op.getRegion().back();
57   firOpBuilder.setInsertionPointToStart(&block);
58   // Ensure the block is well-formed.
59   firOpBuilder.create<mlir::omp::TerminatorOp>(loc);
60   // Reset the insertion point to the start of the first block.
61   firOpBuilder.setInsertionPointToStart(&block);
62 }
63 
64 static void genOMP(Fortran::lower::AbstractConverter &converter,
65                    Fortran::lower::pft::Evaluation &eval,
66                    const Fortran::parser::OpenMPSimpleStandaloneConstruct
67                        &simpleStandaloneConstruct) {
68   const auto &directive =
69       std::get<Fortran::parser::OmpSimpleStandaloneDirective>(
70           simpleStandaloneConstruct.t);
71   switch (directive.v) {
72   default:
73     break;
74   case llvm::omp::Directive::OMPD_barrier:
75     converter.getFirOpBuilder().create<mlir::omp::BarrierOp>(
76         converter.getCurrentLocation());
77     break;
78   case llvm::omp::Directive::OMPD_taskwait:
79     converter.getFirOpBuilder().create<mlir::omp::TaskwaitOp>(
80         converter.getCurrentLocation());
81     break;
82   case llvm::omp::Directive::OMPD_taskyield:
83     converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>(
84         converter.getCurrentLocation());
85     break;
86   case llvm::omp::Directive::OMPD_target_enter_data:
87     TODO(converter.getCurrentLocation(), "OMPD_target_enter_data");
88   case llvm::omp::Directive::OMPD_target_exit_data:
89     TODO(converter.getCurrentLocation(), "OMPD_target_exit_data");
90   case llvm::omp::Directive::OMPD_target_update:
91     TODO(converter.getCurrentLocation(), "OMPD_target_update");
92   case llvm::omp::Directive::OMPD_ordered:
93     TODO(converter.getCurrentLocation(), "OMPD_ordered");
94   }
95 }
96 
97 static void
98 genOMP(Fortran::lower::AbstractConverter &converter,
99        Fortran::lower::pft::Evaluation &eval,
100        const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) {
101   std::visit(
102       Fortran::common::visitors{
103           [&](const Fortran::parser::OpenMPSimpleStandaloneConstruct
104                   &simpleStandaloneConstruct) {
105             genOMP(converter, eval, simpleStandaloneConstruct);
106           },
107           [&](const Fortran::parser::OpenMPFlushConstruct &flushConstruct) {
108             SmallVector<Value, 4> operandRange;
109             if (const auto &ompObjectList =
110                     std::get<std::optional<Fortran::parser::OmpObjectList>>(
111                         flushConstruct.t))
112               genObjectList(*ompObjectList, converter, operandRange);
113             if (std::get<std::optional<
114                     std::list<Fortran::parser::OmpMemoryOrderClause>>>(
115                     flushConstruct.t))
116               TODO(converter.getCurrentLocation(),
117                    "Handle OmpMemoryOrderClause");
118             converter.getFirOpBuilder().create<mlir::omp::FlushOp>(
119                 converter.getCurrentLocation(), operandRange);
120           },
121           [&](const Fortran::parser::OpenMPCancelConstruct &cancelConstruct) {
122             TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct");
123           },
124           [&](const Fortran::parser::OpenMPCancellationPointConstruct
125                   &cancellationPointConstruct) {
126             TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct");
127           },
128       },
129       standaloneConstruct.u);
130 }
131 
132 static void
133 genOMP(Fortran::lower::AbstractConverter &converter,
134        Fortran::lower::pft::Evaluation &eval,
135        const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
136   const auto &beginBlockDirective =
137       std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t);
138   const auto &blockDirective =
139       std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t);
140 
141   auto &firOpBuilder = converter.getFirOpBuilder();
142   auto currentLocation = converter.getCurrentLocation();
143   Fortran::lower::StatementContext stmtCtx;
144   llvm::ArrayRef<mlir::Type> argTy;
145   if (blockDirective.v == llvm::omp::OMPD_parallel) {
146 
147     mlir::Value ifClauseOperand, numThreadsClauseOperand;
148     Attribute procBindClauseOperand;
149 
150     const auto &parallelOpClauseList =
151         std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
152     for (const auto &clause : parallelOpClauseList.v) {
153       if (const auto &ifClause =
154               std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
155         auto &expr =
156             std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
157         ifClauseOperand = fir::getBase(converter.genExprValue(
158             *Fortran::semantics::GetExpr(expr), stmtCtx));
159       } else if (const auto &numThreadsClause =
160                      std::get_if<Fortran::parser::OmpClause::NumThreads>(
161                          &clause.u)) {
162         // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`.
163         numThreadsClauseOperand = fir::getBase(converter.genExprValue(
164             *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx));
165       }
166       // TODO: Handle private, firstprivate, shared and copyin
167     }
168     // Create and insert the operation.
169     auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
170         currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
171         ValueRange(), ValueRange(),
172         procBindClauseOperand.dyn_cast_or_null<omp::ClauseProcBindKindAttr>());
173     // Handle attribute based clauses.
174     for (const auto &clause : parallelOpClauseList.v) {
175       // TODO: Handle default clause
176       if (const auto &procBindClause =
177               std::get_if<Fortran::parser::OmpClause::ProcBind>(&clause.u)) {
178         const auto &ompProcBindClause{procBindClause->v};
179         omp::ClauseProcBindKind pbKind;
180         switch (ompProcBindClause.v) {
181         case Fortran::parser::OmpProcBindClause::Type::Master:
182           pbKind = omp::ClauseProcBindKind::master;
183           break;
184         case Fortran::parser::OmpProcBindClause::Type::Close:
185           pbKind = omp::ClauseProcBindKind::close;
186           break;
187         case Fortran::parser::OmpProcBindClause::Type::Spread:
188           pbKind = omp::ClauseProcBindKind::spread;
189           break;
190         }
191         parallelOp.proc_bind_valAttr(omp::ClauseProcBindKindAttr::get(
192             firOpBuilder.getContext(), pbKind));
193       }
194     }
195     createBodyOfOp<omp::ParallelOp>(parallelOp, firOpBuilder, currentLocation);
196   } else if (blockDirective.v == llvm::omp::OMPD_master) {
197     auto masterOp =
198         firOpBuilder.create<mlir::omp::MasterOp>(currentLocation, argTy);
199     createBodyOfOp<omp::MasterOp>(masterOp, firOpBuilder, currentLocation);
200   }
201 }
202 
203 void Fortran::lower::genOpenMPConstruct(
204     Fortran::lower::AbstractConverter &converter,
205     Fortran::lower::pft::Evaluation &eval,
206     const Fortran::parser::OpenMPConstruct &ompConstruct) {
207 
208   std::visit(
209       common::visitors{
210           [&](const Fortran::parser::OpenMPStandaloneConstruct
211                   &standaloneConstruct) {
212             genOMP(converter, eval, standaloneConstruct);
213           },
214           [&](const Fortran::parser::OpenMPSectionsConstruct
215                   &sectionsConstruct) {
216             TODO(converter.getCurrentLocation(), "OpenMPSectionsConstruct");
217           },
218           [&](const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
219             TODO(converter.getCurrentLocation(), "OpenMPLoopConstruct");
220           },
221           [&](const Fortran::parser::OpenMPDeclarativeAllocate
222                   &execAllocConstruct) {
223             TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate");
224           },
225           [&](const Fortran::parser::OpenMPExecutableAllocate
226                   &execAllocConstruct) {
227             TODO(converter.getCurrentLocation(), "OpenMPExecutableAllocate");
228           },
229           [&](const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
230             genOMP(converter, eval, blockConstruct);
231           },
232           [&](const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) {
233             TODO(converter.getCurrentLocation(), "OpenMPAtomicConstruct");
234           },
235           [&](const Fortran::parser::OpenMPCriticalConstruct
236                   &criticalConstruct) {
237             TODO(converter.getCurrentLocation(), "OpenMPCriticalConstruct");
238           },
239       },
240       ompConstruct.u);
241 }
242