1 //===-- OpenMP.cpp -- Open MP directive lowering --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Lower/OpenMP.h"
14 #include "flang/Common/idioms.h"
15 #include "flang/Lower/Bridge.h"
16 #include "flang/Lower/FIRBuilder.h"
17 #include "flang/Lower/PFTBuilder.h"
18 #include "flang/Lower/Support/BoxValue.h"
19 #include "flang/Lower/Todo.h"
20 #include "flang/Parser/parse-tree.h"
21 #include "flang/Semantics/tools.h"
22 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
23 #include "llvm/Frontend/OpenMP/OMPConstants.h"
24 
25 static const Fortran::parser::Name *
26 getDesignatorNameIfDataRef(const Fortran::parser::Designator &designator) {
27   const auto *dataRef = std::get_if<Fortran::parser::DataRef>(&designator.u);
28   return dataRef ? std::get_if<Fortran::parser::Name>(&dataRef->u) : nullptr;
29 }
30 
31 static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
32                           Fortran::lower::AbstractConverter &converter,
33                           SmallVectorImpl<Value> &operands) {
34   for (const auto &ompObject : objectList.v) {
35     std::visit(
36         Fortran::common::visitors{
37             [&](const Fortran::parser::Designator &designator) {
38               if (const auto *name = getDesignatorNameIfDataRef(designator)) {
39                 const auto variable = converter.getSymbolAddress(*name->symbol);
40                 operands.push_back(variable);
41               }
42             },
43             [&](const Fortran::parser::Name &name) {
44               const auto variable = converter.getSymbolAddress(*name.symbol);
45               operands.push_back(variable);
46             }},
47         ompObject.u);
48   }
49 }
50 
51 template <typename Op>
52 static void createBodyOfOp(Op &op, Fortran::lower::FirOpBuilder &firOpBuilder,
53                            mlir::Location &loc) {
54   firOpBuilder.createBlock(&op.getRegion());
55   auto &block = op.getRegion().back();
56   firOpBuilder.setInsertionPointToStart(&block);
57   // Ensure the block is well-formed.
58   firOpBuilder.create<mlir::omp::TerminatorOp>(loc);
59   // Reset the insertion point to the start of the first block.
60   firOpBuilder.setInsertionPointToStart(&block);
61 }
62 
63 static void genOMP(Fortran::lower::AbstractConverter &converter,
64                    Fortran::lower::pft::Evaluation &eval,
65                    const Fortran::parser::OpenMPSimpleStandaloneConstruct
66                        &simpleStandaloneConstruct) {
67   const auto &directive =
68       std::get<Fortran::parser::OmpSimpleStandaloneDirective>(
69           simpleStandaloneConstruct.t);
70   switch (directive.v) {
71   default:
72     break;
73   case llvm::omp::Directive::OMPD_barrier:
74     converter.getFirOpBuilder().create<mlir::omp::BarrierOp>(
75         converter.getCurrentLocation());
76     break;
77   case llvm::omp::Directive::OMPD_taskwait:
78     converter.getFirOpBuilder().create<mlir::omp::TaskwaitOp>(
79         converter.getCurrentLocation());
80     break;
81   case llvm::omp::Directive::OMPD_taskyield:
82     converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>(
83         converter.getCurrentLocation());
84     break;
85   case llvm::omp::Directive::OMPD_target_enter_data:
86     TODO("");
87   case llvm::omp::Directive::OMPD_target_exit_data:
88     TODO("");
89   case llvm::omp::Directive::OMPD_target_update:
90     TODO("");
91   case llvm::omp::Directive::OMPD_ordered:
92     TODO("");
93   }
94 }
95 
96 static void
97 genOMP(Fortran::lower::AbstractConverter &converter,
98        Fortran::lower::pft::Evaluation &eval,
99        const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) {
100   std::visit(
101       Fortran::common::visitors{
102           [&](const Fortran::parser::OpenMPSimpleStandaloneConstruct
103                   &simpleStandaloneConstruct) {
104             genOMP(converter, eval, simpleStandaloneConstruct);
105           },
106           [&](const Fortran::parser::OpenMPFlushConstruct &flushConstruct) {
107             SmallVector<Value, 4> operandRange;
108             if (const auto &ompObjectList =
109                     std::get<std::optional<Fortran::parser::OmpObjectList>>(
110                         flushConstruct.t))
111               genObjectList(*ompObjectList, converter, operandRange);
112             if (std::get<std::optional<
113                     std::list<Fortran::parser::OmpMemoryOrderClause>>>(
114                     flushConstruct.t))
115               TODO("Handle OmpMemoryOrderClause");
116             converter.getFirOpBuilder().create<mlir::omp::FlushOp>(
117                 converter.getCurrentLocation(), operandRange);
118           },
119           [&](const Fortran::parser::OpenMPCancelConstruct &cancelConstruct) {
120             TODO("");
121           },
122           [&](const Fortran::parser::OpenMPCancellationPointConstruct
123                   &cancellationPointConstruct) { TODO(""); },
124       },
125       standaloneConstruct.u);
126 }
127 
128 static void
129 genOMP(Fortran::lower::AbstractConverter &converter,
130        Fortran::lower::pft::Evaluation &eval,
131        const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
132   const auto &beginBlockDirective =
133       std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t);
134   const auto &blockDirective =
135       std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t);
136 
137   auto &firOpBuilder = converter.getFirOpBuilder();
138   auto currentLocation = converter.getCurrentLocation();
139   llvm::ArrayRef<mlir::Type> argTy;
140   if (blockDirective.v == llvm::omp::OMPD_parallel) {
141 
142     mlir::Value ifClauseOperand, numThreadsClauseOperand;
143     SmallVector<Value, 4> privateClauseOperands, firstprivateClauseOperands,
144         sharedClauseOperands, copyinClauseOperands;
145     Attribute defaultClauseOperand, procBindClauseOperand;
146 
147     const auto &parallelOpClauseList =
148         std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
149     for (const auto &clause : parallelOpClauseList.v) {
150       if (const auto &ifClause =
151               std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
152         auto &expr =
153             std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
154         ifClauseOperand = fir::getBase(
155             converter.genExprValue(*Fortran::semantics::GetExpr(expr)));
156       } else if (const auto &numThreadsClause =
157                      std::get_if<Fortran::parser::OmpClause::NumThreads>(
158                          &clause.u)) {
159         // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`.
160         numThreadsClauseOperand = fir::getBase(converter.genExprValue(
161             *Fortran::semantics::GetExpr(numThreadsClause->v)));
162       } else if (const auto &privateClause =
163                      std::get_if<Fortran::parser::OmpClause::Private>(
164                          &clause.u)) {
165         const Fortran::parser::OmpObjectList &ompObjectList = privateClause->v;
166         genObjectList(ompObjectList, converter, privateClauseOperands);
167       } else if (const auto &firstprivateClause =
168                      std::get_if<Fortran::parser::OmpClause::Firstprivate>(
169                          &clause.u)) {
170         const Fortran::parser::OmpObjectList &ompObjectList =
171             firstprivateClause->v;
172         genObjectList(ompObjectList, converter, firstprivateClauseOperands);
173       } else if (const auto &sharedClause =
174                      std::get_if<Fortran::parser::OmpClause::Shared>(
175                          &clause.u)) {
176         const Fortran::parser::OmpObjectList &ompObjectList = sharedClause->v;
177         genObjectList(ompObjectList, converter, sharedClauseOperands);
178       } else if (const auto &copyinClause =
179                      std::get_if<Fortran::parser::OmpClause::Copyin>(
180                          &clause.u)) {
181         const Fortran::parser::OmpObjectList &ompObjectList = copyinClause->v;
182         genObjectList(ompObjectList, converter, copyinClauseOperands);
183       }
184     }
185     // Create and insert the operation.
186     auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
187         currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
188         defaultClauseOperand.dyn_cast_or_null<StringAttr>(),
189         privateClauseOperands, firstprivateClauseOperands, sharedClauseOperands,
190         copyinClauseOperands, ValueRange(), ValueRange(),
191         procBindClauseOperand.dyn_cast_or_null<StringAttr>());
192     // Handle attribute based clauses.
193     for (const auto &clause : parallelOpClauseList.v) {
194       if (const auto &defaultClause =
195               std::get_if<Fortran::parser::OmpClause::Default>(&clause.u)) {
196         const auto &ompDefaultClause{defaultClause->v};
197         switch (ompDefaultClause.v) {
198         case Fortran::parser::OmpDefaultClause::Type::Private:
199           parallelOp.default_valAttr(firOpBuilder.getStringAttr(
200               omp::stringifyClauseDefault(omp::ClauseDefault::defprivate)));
201           break;
202         case Fortran::parser::OmpDefaultClause::Type::Firstprivate:
203           parallelOp.default_valAttr(
204               firOpBuilder.getStringAttr(omp::stringifyClauseDefault(
205                   omp::ClauseDefault::deffirstprivate)));
206           break;
207         case Fortran::parser::OmpDefaultClause::Type::Shared:
208           parallelOp.default_valAttr(firOpBuilder.getStringAttr(
209               omp::stringifyClauseDefault(omp::ClauseDefault::defshared)));
210           break;
211         case Fortran::parser::OmpDefaultClause::Type::None:
212           parallelOp.default_valAttr(firOpBuilder.getStringAttr(
213               omp::stringifyClauseDefault(omp::ClauseDefault::defnone)));
214           break;
215         }
216       }
217       if (const auto &procBindClause =
218               std::get_if<Fortran::parser::OmpClause::ProcBind>(&clause.u)) {
219         const auto &ompProcBindClause{procBindClause->v};
220         switch (ompProcBindClause.v) {
221         case Fortran::parser::OmpProcBindClause::Type::Master:
222           parallelOp.proc_bind_valAttr(
223               firOpBuilder.getStringAttr(omp::stringifyClauseProcBindKind(
224                   omp::ClauseProcBindKind::master)));
225           break;
226         case Fortran::parser::OmpProcBindClause::Type::Close:
227           parallelOp.proc_bind_valAttr(
228               firOpBuilder.getStringAttr(omp::stringifyClauseProcBindKind(
229                   omp::ClauseProcBindKind::close)));
230           break;
231         case Fortran::parser::OmpProcBindClause::Type::Spread:
232           parallelOp.proc_bind_valAttr(
233               firOpBuilder.getStringAttr(omp::stringifyClauseProcBindKind(
234                   omp::ClauseProcBindKind::spread)));
235           break;
236         }
237       }
238     }
239     createBodyOfOp<omp::ParallelOp>(parallelOp, firOpBuilder, currentLocation);
240   } else if (blockDirective.v == llvm::omp::OMPD_master) {
241     auto masterOp =
242         firOpBuilder.create<mlir::omp::MasterOp>(currentLocation, argTy);
243     createBodyOfOp<omp::MasterOp>(masterOp, firOpBuilder, currentLocation);
244   }
245 }
246 
247 void Fortran::lower::genOpenMPConstruct(
248     Fortran::lower::AbstractConverter &converter,
249     Fortran::lower::pft::Evaluation &eval,
250     const Fortran::parser::OpenMPConstruct &ompConstruct) {
251 
252   std::visit(
253       common::visitors{
254           [&](const Fortran::parser::OpenMPStandaloneConstruct
255                   &standaloneConstruct) {
256             genOMP(converter, eval, standaloneConstruct);
257           },
258           [&](const Fortran::parser::OpenMPSectionsConstruct
259                   &sectionsConstruct) { TODO(""); },
260           [&](const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
261             TODO("");
262           },
263           [&](const Fortran::parser::OpenMPDeclarativeAllocate
264                   &execAllocConstruct) { TODO(""); },
265           [&](const Fortran::parser::OpenMPExecutableAllocate
266                   &execAllocConstruct) { TODO(""); },
267           [&](const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
268             genOMP(converter, eval, blockConstruct);
269           },
270           [&](const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) {
271             TODO("");
272           },
273           [&](const Fortran::parser::OpenMPCriticalConstruct
274                   &criticalConstruct) { TODO(""); },
275       },
276       ompConstruct.u);
277 }
278 
279 void Fortran::lower::genOpenMPEndLoop(
280     Fortran::lower::AbstractConverter &, Fortran::lower::pft::Evaluation &,
281     const Fortran::parser::OmpEndLoopDirective &) {
282   TODO("");
283 }
284