//===-- OpenMP.cpp -- OpenMP directive lowering ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Lower/OpenMP.h"
14 #include "flang/Common/idioms.h"
15 #include "flang/Lower/Bridge.h"
16 #include "flang/Lower/PFTBuilder.h"
17 #include "flang/Lower/Todo.h"
18 #include "flang/Optimizer/Builder/BoxValue.h"
19 #include "flang/Optimizer/Builder/FIRBuilder.h"
20 #include "flang/Parser/parse-tree.h"
21 #include "flang/Semantics/tools.h"
22 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
23 #include "llvm/Frontend/OpenMP/OMPConstants.h"
24 
25 static const Fortran::parser::Name *
26 getDesignatorNameIfDataRef(const Fortran::parser::Designator &designator) {
27   const auto *dataRef = std::get_if<Fortran::parser::DataRef>(&designator.u);
28   return dataRef ? std::get_if<Fortran::parser::Name>(&dataRef->u) : nullptr;
29 }
30 
31 static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
32                           Fortran::lower::AbstractConverter &converter,
33                           SmallVectorImpl<Value> &operands) {
34   for (const auto &ompObject : objectList.v) {
35     std::visit(
36         Fortran::common::visitors{
37             [&](const Fortran::parser::Designator &designator) {
38               if (const auto *name = getDesignatorNameIfDataRef(designator)) {
39                 const auto variable = converter.getSymbolAddress(*name->symbol);
40                 operands.push_back(variable);
41               }
42             },
43             [&](const Fortran::parser::Name &name) {
44               const auto variable = converter.getSymbolAddress(*name.symbol);
45               operands.push_back(variable);
46             }},
47         ompObject.u);
48   }
49 }
50 
51 template <typename Op>
52 static void createBodyOfOp(Op &op, fir::FirOpBuilder &firOpBuilder,
53                            mlir::Location &loc) {
54   firOpBuilder.createBlock(&op.getRegion());
55   auto &block = op.getRegion().back();
56   firOpBuilder.setInsertionPointToStart(&block);
57   // Ensure the block is well-formed.
58   firOpBuilder.create<mlir::omp::TerminatorOp>(loc);
59   // Reset the insertion point to the start of the first block.
60   firOpBuilder.setInsertionPointToStart(&block);
61 }
62 
63 static void genOMP(Fortran::lower::AbstractConverter &converter,
64                    Fortran::lower::pft::Evaluation &eval,
65                    const Fortran::parser::OpenMPSimpleStandaloneConstruct
66                        &simpleStandaloneConstruct) {
67   const auto &directive =
68       std::get<Fortran::parser::OmpSimpleStandaloneDirective>(
69           simpleStandaloneConstruct.t);
70   switch (directive.v) {
71   default:
72     break;
73   case llvm::omp::Directive::OMPD_barrier:
74     converter.getFirOpBuilder().create<mlir::omp::BarrierOp>(
75         converter.getCurrentLocation());
76     break;
77   case llvm::omp::Directive::OMPD_taskwait:
78     converter.getFirOpBuilder().create<mlir::omp::TaskwaitOp>(
79         converter.getCurrentLocation());
80     break;
81   case llvm::omp::Directive::OMPD_taskyield:
82     converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>(
83         converter.getCurrentLocation());
84     break;
85   case llvm::omp::Directive::OMPD_target_enter_data:
86     TODO(converter.getCurrentLocation(), "OMPD_target_enter_data");
87   case llvm::omp::Directive::OMPD_target_exit_data:
88     TODO(converter.getCurrentLocation(), "OMPD_target_exit_data");
89   case llvm::omp::Directive::OMPD_target_update:
90     TODO(converter.getCurrentLocation(), "OMPD_target_update");
91   case llvm::omp::Directive::OMPD_ordered:
92     TODO(converter.getCurrentLocation(), "OMPD_ordered");
93   }
94 }
95 
96 static void
97 genOMP(Fortran::lower::AbstractConverter &converter,
98        Fortran::lower::pft::Evaluation &eval,
99        const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) {
100   std::visit(
101       Fortran::common::visitors{
102           [&](const Fortran::parser::OpenMPSimpleStandaloneConstruct
103                   &simpleStandaloneConstruct) {
104             genOMP(converter, eval, simpleStandaloneConstruct);
105           },
106           [&](const Fortran::parser::OpenMPFlushConstruct &flushConstruct) {
107             SmallVector<Value, 4> operandRange;
108             if (const auto &ompObjectList =
109                     std::get<std::optional<Fortran::parser::OmpObjectList>>(
110                         flushConstruct.t))
111               genObjectList(*ompObjectList, converter, operandRange);
112             if (std::get<std::optional<
113                     std::list<Fortran::parser::OmpMemoryOrderClause>>>(
114                     flushConstruct.t))
115               TODO(converter.getCurrentLocation(),
116                    "Handle OmpMemoryOrderClause");
117             converter.getFirOpBuilder().create<mlir::omp::FlushOp>(
118                 converter.getCurrentLocation(), operandRange);
119           },
120           [&](const Fortran::parser::OpenMPCancelConstruct &cancelConstruct) {
121             TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct");
122           },
123           [&](const Fortran::parser::OpenMPCancellationPointConstruct
124                   &cancellationPointConstruct) {
125             TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct");
126           },
127       },
128       standaloneConstruct.u);
129 }
130 
131 static void
132 genOMP(Fortran::lower::AbstractConverter &converter,
133        Fortran::lower::pft::Evaluation &eval,
134        const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
135   const auto &beginBlockDirective =
136       std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t);
137   const auto &blockDirective =
138       std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t);
139 
140   auto &firOpBuilder = converter.getFirOpBuilder();
141   auto currentLocation = converter.getCurrentLocation();
142   llvm::ArrayRef<mlir::Type> argTy;
143   if (blockDirective.v == llvm::omp::OMPD_parallel) {
144 
145     mlir::Value ifClauseOperand, numThreadsClauseOperand;
146     Attribute procBindClauseOperand;
147 
148     const auto &parallelOpClauseList =
149         std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t);
150     for (const auto &clause : parallelOpClauseList.v) {
151       if (const auto &ifClause =
152               std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
153         auto &expr =
154             std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
155         ifClauseOperand = fir::getBase(
156             converter.genExprValue(*Fortran::semantics::GetExpr(expr)));
157       } else if (const auto &numThreadsClause =
158                      std::get_if<Fortran::parser::OmpClause::NumThreads>(
159                          &clause.u)) {
160         // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`.
161         numThreadsClauseOperand = fir::getBase(converter.genExprValue(
162             *Fortran::semantics::GetExpr(numThreadsClause->v)));
163       }
164       // TODO: Handle private, firstprivate, shared and copyin
165     }
166     // Create and insert the operation.
167     auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
168         currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand,
169         ValueRange(), ValueRange(),
170         procBindClauseOperand.dyn_cast_or_null<omp::ClauseProcBindKindAttr>());
171     // Handle attribute based clauses.
172     for (const auto &clause : parallelOpClauseList.v) {
173       // TODO: Handle default clause
174       if (const auto &procBindClause =
175               std::get_if<Fortran::parser::OmpClause::ProcBind>(&clause.u)) {
176         const auto &ompProcBindClause{procBindClause->v};
177         omp::ClauseProcBindKind pbKind;
178         switch (ompProcBindClause.v) {
179         case Fortran::parser::OmpProcBindClause::Type::Master:
180           pbKind = omp::ClauseProcBindKind::master;
181           break;
182         case Fortran::parser::OmpProcBindClause::Type::Close:
183           pbKind = omp::ClauseProcBindKind::close;
184           break;
185         case Fortran::parser::OmpProcBindClause::Type::Spread:
186           pbKind = omp::ClauseProcBindKind::spread;
187           break;
188         }
189         parallelOp.proc_bind_valAttr(omp::ClauseProcBindKindAttr::get(
190             firOpBuilder.getContext(), pbKind));
191       }
192     }
193     createBodyOfOp<omp::ParallelOp>(parallelOp, firOpBuilder, currentLocation);
194   } else if (blockDirective.v == llvm::omp::OMPD_master) {
195     auto masterOp =
196         firOpBuilder.create<mlir::omp::MasterOp>(currentLocation, argTy);
197     createBodyOfOp<omp::MasterOp>(masterOp, firOpBuilder, currentLocation);
198   }
199 }
200 
201 void Fortran::lower::genOpenMPConstruct(
202     Fortran::lower::AbstractConverter &converter,
203     Fortran::lower::pft::Evaluation &eval,
204     const Fortran::parser::OpenMPConstruct &ompConstruct) {
205 
206   std::visit(
207       common::visitors{
208           [&](const Fortran::parser::OpenMPStandaloneConstruct
209                   &standaloneConstruct) {
210             genOMP(converter, eval, standaloneConstruct);
211           },
212           [&](const Fortran::parser::OpenMPSectionsConstruct
213                   &sectionsConstruct) {
214             TODO(converter.getCurrentLocation(), "OpenMPSectionsConstruct");
215           },
216           [&](const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
217             TODO(converter.getCurrentLocation(), "OpenMPLoopConstruct");
218           },
219           [&](const Fortran::parser::OpenMPDeclarativeAllocate
220                   &execAllocConstruct) {
221             TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate");
222           },
223           [&](const Fortran::parser::OpenMPExecutableAllocate
224                   &execAllocConstruct) {
225             TODO(converter.getCurrentLocation(), "OpenMPExecutableAllocate");
226           },
227           [&](const Fortran::parser::OpenMPBlockConstruct &blockConstruct) {
228             genOMP(converter, eval, blockConstruct);
229           },
230           [&](const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) {
231             TODO(converter.getCurrentLocation(), "OpenMPAtomicConstruct");
232           },
233           [&](const Fortran::parser::OpenMPCriticalConstruct
234                   &criticalConstruct) {
235             TODO(converter.getCurrentLocation(), "OpenMPCriticalConstruct");
236           },
237       },
238       ompConstruct.u);
239 }
240