1 //===-- OpenMP.cpp -- Open MP directive lowering --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Lower/OpenMP.h"
14 #include "flang/Common/idioms.h"
15 #include "flang/Lower/Bridge.h"
16 #include "flang/Lower/PFTBuilder.h"
17 #include "flang/Lower/StatementContext.h"
18 #include "flang/Lower/Todo.h"
19 #include "flang/Optimizer/Builder/BoxValue.h"
20 #include "flang/Optimizer/Builder/FIRBuilder.h"
21 #include "flang/Parser/parse-tree.h"
22 #include "flang/Semantics/tools.h"
23 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
24 #include "llvm/Frontend/OpenMP/OMPConstants.h"
25 
26 using namespace mlir;
27 
28 static const Fortran::parser::Name *
29 getDesignatorNameIfDataRef(const Fortran::parser::Designator &designator) {
30   const auto *dataRef = std::get_if<Fortran::parser::DataRef>(&designator.u);
31   return dataRef ? std::get_if<Fortran::parser::Name>(&dataRef->u) : nullptr;
32 }
33 
34 static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
35                           Fortran::lower::AbstractConverter &converter,
36                           SmallVectorImpl<Value> &operands) {
37   for (const auto &ompObject : objectList.v) {
38     std::visit(
39         Fortran::common::visitors{
40             [&](const Fortran::parser::Designator &designator) {
41               if (const auto *name = getDesignatorNameIfDataRef(designator)) {
42                 const auto variable = converter.getSymbolAddress(*name->symbol);
43                 operands.push_back(variable);
44               }
45             },
46             [&](const Fortran::parser::Name &name) {
47               const auto variable = converter.getSymbolAddress(*name.symbol);
48               operands.push_back(variable);
49             }},
50         ompObject.u);
51   }
52 }
53 
54 template <typename Op>
55 static void createBodyOfOp(Op &op, fir::FirOpBuilder &firOpBuilder,
56                            mlir::Location &loc) {
57   firOpBuilder.createBlock(&op.getRegion());
58   auto &block = op.getRegion().back();
59   firOpBuilder.setInsertionPointToStart(&block);
60   // Ensure the block is well-formed.
61   firOpBuilder.create<mlir::omp::TerminatorOp>(loc);
62   // Reset the insertion point to the start of the first block.
63   firOpBuilder.setInsertionPointToStart(&block);
64 }
65 
66 static void genOMP(Fortran::lower::AbstractConverter &converter,
67                    Fortran::lower::pft::Evaluation &eval,
68                    const Fortran::parser::OpenMPSimpleStandaloneConstruct
69                        &simpleStandaloneConstruct) {
70   const auto &directive =
71       std::get<Fortran::parser::OmpSimpleStandaloneDirective>(
72           simpleStandaloneConstruct.t);
73   switch (directive.v) {
74   default:
75     break;
76   case llvm::omp::Directive::OMPD_barrier:
77     converter.getFirOpBuilder().create<mlir::omp::BarrierOp>(
78         converter.getCurrentLocation());
79     break;
80   case llvm::omp::Directive::OMPD_taskwait:
81     converter.getFirOpBuilder().create<mlir::omp::TaskwaitOp>(
82         converter.getCurrentLocation());
83     break;
84   case llvm::omp::Directive::OMPD_taskyield:
85     converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>(
86         converter.getCurrentLocation());
87     break;
88   case llvm::omp::Directive::OMPD_target_enter_data:
89     TODO(converter.getCurrentLocation(), "OMPD_target_enter_data");
90   case llvm::omp::Directive::OMPD_target_exit_data:
91     TODO(converter.getCurrentLocation(), "OMPD_target_exit_data");
92   case llvm::omp::Directive::OMPD_target_update:
93     TODO(converter.getCurrentLocation(), "OMPD_target_update");
94   case llvm::omp::Directive::OMPD_ordered:
95     TODO(converter.getCurrentLocation(), "OMPD_ordered");
96   }
97 }
98 
99 static void
100 genOMP(Fortran::lower::AbstractConverter &converter,
101        Fortran::lower::pft::Evaluation &eval,
102        const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) {
103   std::visit(
104       Fortran::common::visitors{
105           [&](const Fortran::parser::OpenMPSimpleStandaloneConstruct
106                   &simpleStandaloneConstruct) {
107             genOMP(converter, eval, simpleStandaloneConstruct);
108           },
109           [&](const Fortran::parser::OpenMPFlushConstruct &flushConstruct) {
110             SmallVector<Value, 4> operandRange;
111             if (const auto &ompObjectList =
112                     std::get<std::optional<Fortran::parser::OmpObjectList>>(
113                         flushConstruct.t))
114               genObjectList(*ompObjectList, converter, operandRange);
115             if (std::get<std::optional<
116                     std::list<Fortran::parser::OmpMemoryOrderClause>>>(
117                     flushConstruct.t))
118               TODO(converter.getCurrentLocation(),
119                    "Handle OmpMemoryOrderClause");
120             converter.getFirOpBuilder().create<mlir::omp::FlushOp>(
121                 converter.getCurrentLocation(), operandRange);
122           },
123           [&](const Fortran::parser::OpenMPCancelConstruct &cancelConstruct) {
124             TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct");
125           },
126           [&](const Fortran::parser::OpenMPCancellationPointConstruct
127                   &cancellationPointConstruct) {
128             TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct");
129           },
130       },
131       standaloneConstruct.u);
132 }
133 
134 static void
135 genOMP(Fortran::lower::AbstractConverter &converter,
136        Fortran::lower::pft::Evaluation &eval,
137        const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
138   const auto &beginBlockDirective =
139       std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t);
140   const auto &blockDirective =
141       std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t);
142 
143   auto &firOpBuilder = converter.getFirOpBuilder();
144   auto currentLocation = converter.getCurrentLocation();
145   Fortran::lower::StatementContext stmtCtx;
146   llvm::ArrayRef<mlir::Type> argTy;
147   if (blockDirective.v == llvm::omp::OMPD_parallel) {
148 
149     mlir::Value ifClauseOperand, numThreadsClauseOperand;
150     Attribute procBindClauseOperand;
151 
152     const auto &parallelOpClauseList =
153         std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
154     for (const auto &clause : parallelOpClauseList.v) {
155       if (const auto &ifClause =
156               std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
157         auto &expr =
158             std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
159         ifClauseOperand = fir::getBase(converter.genExprValue(
160             *Fortran::semantics::GetExpr(expr), stmtCtx));
161       } else if (const auto &numThreadsClause =
162                      std::get_if<Fortran::parser::OmpClause::NumThreads>(
163                          &clause.u)) {
164         // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`.
165         numThreadsClauseOperand = fir::getBase(converter.genExprValue(
166             *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx));
167       }
168       // TODO: Handle private, firstprivate, shared and copyin
169     }
170     // Create and insert the operation.
171     auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
172         currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
173         ValueRange(), ValueRange(),
174         procBindClauseOperand.dyn_cast_or_null<omp::ClauseProcBindKindAttr>());
175     // Handle attribute based clauses.
176     for (const auto &clause : parallelOpClauseList.v) {
177       // TODO: Handle default clause
178       if (const auto &procBindClause =
179               std::get_if<Fortran::parser::OmpClause::ProcBind>(&clause.u)) {
180         const auto &ompProcBindClause{procBindClause->v};
181         omp::ClauseProcBindKind pbKind;
182         switch (ompProcBindClause.v) {
183         case Fortran::parser::OmpProcBindClause::Type::Master:
184           pbKind = omp::ClauseProcBindKind::Master;
185           break;
186         case Fortran::parser::OmpProcBindClause::Type::Close:
187           pbKind = omp::ClauseProcBindKind::Close;
188           break;
189         case Fortran::parser::OmpProcBindClause::Type::Spread:
190           pbKind = omp::ClauseProcBindKind::Spread;
191           break;
192         }
193         parallelOp.proc_bind_valAttr(omp::ClauseProcBindKindAttr::get(
194             firOpBuilder.getContext(), pbKind));
195       }
196     }
197     createBodyOfOp<omp::ParallelOp>(parallelOp, firOpBuilder, currentLocation);
198   } else if (blockDirective.v == llvm::omp::OMPD_master) {
199     auto masterOp =
200         firOpBuilder.create<mlir::omp::MasterOp>(currentLocation, argTy);
201     createBodyOfOp<omp::MasterOp>(masterOp, firOpBuilder, currentLocation);
202   }
203 }
204 
205 void Fortran::lower::genOpenMPConstruct(
206     Fortran::lower::AbstractConverter &converter,
207     Fortran::lower::pft::Evaluation &eval,
208     const Fortran::parser::OpenMPConstruct &ompConstruct) {
209 
210   std::visit(
211       common::visitors{
212           [&](const Fortran::parser::OpenMPStandaloneConstruct
213                   &standaloneConstruct) {
214             genOMP(converter, eval, standaloneConstruct);
215           },
216           [&](const Fortran::parser::OpenMPSectionsConstruct
217                   &sectionsConstruct) {
218             TODO(converter.getCurrentLocation(), "OpenMPSectionsConstruct");
219           },
220           [&](const Fortran::parser::OpenMPSectionConstruct &sectionConstruct) {
221             TODO(converter.getCurrentLocation(), "OpenMPSectionConstruct");
222           },
223           [&](const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
224             TODO(converter.getCurrentLocation(), "OpenMPLoopConstruct");
225           },
226           [&](const Fortran::parser::OpenMPDeclarativeAllocate
227                   &execAllocConstruct) {
228             TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate");
229           },
230           [&](const Fortran::parser::OpenMPExecutableAllocate
231                   &execAllocConstruct) {
232             TODO(converter.getCurrentLocation(), "OpenMPExecutableAllocate");
233           },
234           [&](const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
235             genOMP(converter, eval, blockConstruct);
236           },
237           [&](const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) {
238             TODO(converter.getCurrentLocation(), "OpenMPAtomicConstruct");
239           },
240           [&](const Fortran::parser::OpenMPCriticalConstruct
241                   &criticalConstruct) {
242             TODO(converter.getCurrentLocation(), "OpenMPCriticalConstruct");
243           },
244       },
245       ompConstruct.u);
246 }
247