1 //===-- OpenMP.cpp -- Open MP directive lowering --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Lower/OpenMP.h" 14 #include "flang/Common/idioms.h" 15 #include "flang/Lower/Bridge.h" 16 #include "flang/Lower/PFTBuilder.h" 17 #include "flang/Lower/StatementContext.h" 18 #include "flang/Lower/Todo.h" 19 #include "flang/Optimizer/Builder/BoxValue.h" 20 #include "flang/Optimizer/Builder/FIRBuilder.h" 21 #include "flang/Parser/parse-tree.h" 22 #include "flang/Semantics/tools.h" 23 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 24 #include "llvm/Frontend/OpenMP/OMPConstants.h" 25 26 using namespace mlir; 27 28 static const Fortran::parser::Name * 29 getDesignatorNameIfDataRef(const Fortran::parser::Designator &designator) { 30 const auto *dataRef = std::get_if<Fortran::parser::DataRef>(&designator.u); 31 return dataRef ? std::get_if<Fortran::parser::Name>(&dataRef->u) : nullptr; 32 } 33 34 static void genObjectList(const Fortran::parser::OmpObjectList &objectList, 35 Fortran::lower::AbstractConverter &converter, 36 SmallVectorImpl<Value> &operands) { 37 for (const auto &ompObject : objectList.v) { 38 std::visit( 39 Fortran::common::visitors{ 40 [&](const Fortran::parser::Designator &designator) { 41 if (const auto *name = getDesignatorNameIfDataRef(designator)) { 42 const auto variable = converter.getSymbolAddress(*name->symbol); 43 operands.push_back(variable); 44 } 45 }, 46 [&](const Fortran::parser::Name &name) { 47 const auto variable = converter.getSymbolAddress(*name.symbol); 48 operands.push_back(variable); 49 }}, 50 ompObject.u); 51 } 52 } 53 54 template <typename Op> 55 static void createBodyOfOp(Op &op, fir::FirOpBuilder &firOpBuilder, 56 mlir::Location &loc) { 57 firOpBuilder.createBlock(&op.getRegion()); 58 auto &block = op.getRegion().back(); 59 firOpBuilder.setInsertionPointToStart(&block); 60 // Ensure the block is well-formed. 61 firOpBuilder.create<mlir::omp::TerminatorOp>(loc); 62 // Reset the insertion point to the start of the first block. 63 firOpBuilder.setInsertionPointToStart(&block); 64 } 65 66 static void genOMP(Fortran::lower::AbstractConverter &converter, 67 Fortran::lower::pft::Evaluation &eval, 68 const Fortran::parser::OpenMPSimpleStandaloneConstruct 69 &simpleStandaloneConstruct) { 70 const auto &directive = 71 std::get<Fortran::parser::OmpSimpleStandaloneDirective>( 72 simpleStandaloneConstruct.t); 73 switch (directive.v) { 74 default: 75 break; 76 case llvm::omp::Directive::OMPD_barrier: 77 converter.getFirOpBuilder().create<mlir::omp::BarrierOp>( 78 converter.getCurrentLocation()); 79 break; 80 case llvm::omp::Directive::OMPD_taskwait: 81 converter.getFirOpBuilder().create<mlir::omp::TaskwaitOp>( 82 converter.getCurrentLocation()); 83 break; 84 case llvm::omp::Directive::OMPD_taskyield: 85 converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>( 86 converter.getCurrentLocation()); 87 break; 88 case llvm::omp::Directive::OMPD_target_enter_data: 89 TODO(converter.getCurrentLocation(), "OMPD_target_enter_data"); 90 case llvm::omp::Directive::OMPD_target_exit_data: 91 TODO(converter.getCurrentLocation(), "OMPD_target_exit_data"); 92 case llvm::omp::Directive::OMPD_target_update: 93 TODO(converter.getCurrentLocation(), "OMPD_target_update"); 94 case llvm::omp::Directive::OMPD_ordered: 95 TODO(converter.getCurrentLocation(), "OMPD_ordered"); 96 } 97 } 98 99 static void 100 genOMP(Fortran::lower::AbstractConverter &converter, 101 Fortran::lower::pft::Evaluation &eval, 102 const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) { 103 std::visit( 104 Fortran::common::visitors{ 105 [&](const Fortran::parser::OpenMPSimpleStandaloneConstruct 106 &simpleStandaloneConstruct) { 107 genOMP(converter, eval, simpleStandaloneConstruct); 108 }, 109 [&](const Fortran::parser::OpenMPFlushConstruct &flushConstruct) { 110 SmallVector<Value, 4> operandRange; 111 if (const auto &ompObjectList = 112 std::get<std::optional<Fortran::parser::OmpObjectList>>( 113 flushConstruct.t)) 114 genObjectList(*ompObjectList, converter, operandRange); 115 const auto &memOrderClause = std::get<std::optional< 116 std::list<Fortran::parser::OmpMemoryOrderClause>>>( 117 flushConstruct.t); 118 if (memOrderClause.has_value() && memOrderClause->size() > 0) 119 TODO(converter.getCurrentLocation(), 120 "Handle OmpMemoryOrderClause"); 121 converter.getFirOpBuilder().create<mlir::omp::FlushOp>( 122 converter.getCurrentLocation(), operandRange); 123 }, 124 [&](const Fortran::parser::OpenMPCancelConstruct &cancelConstruct) { 125 TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct"); 126 }, 127 [&](const Fortran::parser::OpenMPCancellationPointConstruct 128 &cancellationPointConstruct) { 129 TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct"); 130 }, 131 }, 132 standaloneConstruct.u); 133 } 134 135 static void 136 genOMP(Fortran::lower::AbstractConverter &converter, 137 Fortran::lower::pft::Evaluation &eval, 138 const Fortran::parser::OpenMPBlockConstruct &blockConstruct) { 139 const auto &beginBlockDirective = 140 std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t); 141 const auto &blockDirective = 142 std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t); 143 144 auto &firOpBuilder = converter.getFirOpBuilder(); 145 auto currentLocation = converter.getCurrentLocation(); 146 Fortran::lower::StatementContext stmtCtx; 147 llvm::ArrayRef<mlir::Type> argTy; 148 if (blockDirective.v == llvm::omp::OMPD_parallel) { 149 150 mlir::Value ifClauseOperand, numThreadsClauseOperand; 151 Attribute procBindClauseOperand; 152 153 const auto ¶llelOpClauseList = 154 std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t); 155 for (const auto &clause : parallelOpClauseList.v) { 156 if (const auto &ifClause = 157 std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) { 158 auto &expr = 159 std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t); 160 ifClauseOperand = fir::getBase(converter.genExprValue( 161 *Fortran::semantics::GetExpr(expr), stmtCtx)); 162 } else if (const auto &numThreadsClause = 163 std::get_if<Fortran::parser::OmpClause::NumThreads>( 164 &clause.u)) { 165 // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`. 166 numThreadsClauseOperand = fir::getBase(converter.genExprValue( 167 *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx)); 168 } 169 // TODO: Handle private, firstprivate, shared and copyin 170 } 171 // Create and insert the operation. 172 auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>( 173 currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand, 174 ValueRange(), ValueRange(), 175 procBindClauseOperand.dyn_cast_or_null<omp::ClauseProcBindKindAttr>()); 176 // Handle attribute based clauses. 177 for (const auto &clause : parallelOpClauseList.v) { 178 // TODO: Handle default clause 179 if (const auto &procBindClause = 180 std::get_if<Fortran::parser::OmpClause::ProcBind>(&clause.u)) { 181 const auto &ompProcBindClause{procBindClause->v}; 182 omp::ClauseProcBindKind pbKind; 183 switch (ompProcBindClause.v) { 184 case Fortran::parser::OmpProcBindClause::Type::Master: 185 pbKind = omp::ClauseProcBindKind::Master; 186 break; 187 case Fortran::parser::OmpProcBindClause::Type::Close: 188 pbKind = omp::ClauseProcBindKind::Close; 189 break; 190 case Fortran::parser::OmpProcBindClause::Type::Spread: 191 pbKind = omp::ClauseProcBindKind::Spread; 192 break; 193 } 194 parallelOp.proc_bind_valAttr(omp::ClauseProcBindKindAttr::get( 195 firOpBuilder.getContext(), pbKind)); 196 } 197 } 198 createBodyOfOp<omp::ParallelOp>(parallelOp, firOpBuilder, currentLocation); 199 } else if (blockDirective.v == llvm::omp::OMPD_master) { 200 auto masterOp = 201 firOpBuilder.create<mlir::omp::MasterOp>(currentLocation, argTy); 202 createBodyOfOp<omp::MasterOp>(masterOp, firOpBuilder, currentLocation); 203 } 204 } 205 206 static void 207 genOMP(Fortran::lower::AbstractConverter &converter, 208 Fortran::lower::pft::Evaluation &eval, 209 const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) { 210 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 211 mlir::Location currentLocation = converter.getCurrentLocation(); 212 std::string name; 213 const Fortran::parser::OmpCriticalDirective &cd = 214 std::get<Fortran::parser::OmpCriticalDirective>(criticalConstruct.t); 215 if (std::get<std::optional<Fortran::parser::Name>>(cd.t).has_value()) { 216 name = 217 std::get<std::optional<Fortran::parser::Name>>(cd.t).value().ToString(); 218 } 219 220 uint64_t hint = 0; 221 const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t); 222 for (const Fortran::parser::OmpClause &clause : clauseList.v) 223 if (auto hintClause = 224 std::get_if<Fortran::parser::OmpClause::Hint>(&clause.u)) { 225 const auto *expr = Fortran::semantics::GetExpr(hintClause->v); 226 hint = *Fortran::evaluate::ToInt64(*expr); 227 break; 228 } 229 230 mlir::omp::CriticalOp criticalOp = [&]() { 231 if (name.empty()) { 232 return firOpBuilder.create<mlir::omp::CriticalOp>(currentLocation, 233 FlatSymbolRefAttr()); 234 } else { 235 mlir::ModuleOp module = firOpBuilder.getModule(); 236 mlir::OpBuilder modBuilder(module.getBodyRegion()); 237 auto global = module.lookupSymbol<mlir::omp::CriticalDeclareOp>(name); 238 if (!global) 239 global = modBuilder.create<mlir::omp::CriticalDeclareOp>( 240 currentLocation, name, hint); 241 return firOpBuilder.create<mlir::omp::CriticalOp>( 242 currentLocation, mlir::FlatSymbolRefAttr::get( 243 firOpBuilder.getContext(), global.sym_name())); 244 } 245 }(); 246 createBodyOfOp<omp::CriticalOp>(criticalOp, firOpBuilder, currentLocation); 247 } 248 249 void Fortran::lower::genOpenMPConstruct( 250 Fortran::lower::AbstractConverter &converter, 251 Fortran::lower::pft::Evaluation &eval, 252 const Fortran::parser::OpenMPConstruct &ompConstruct) { 253 254 std::visit( 255 common::visitors{ 256 [&](const Fortran::parser::OpenMPStandaloneConstruct 257 &standaloneConstruct) { 258 genOMP(converter, eval, standaloneConstruct); 259 }, 260 [&](const Fortran::parser::OpenMPSectionsConstruct 261 §ionsConstruct) { 262 TODO(converter.getCurrentLocation(), "OpenMPSectionsConstruct"); 263 }, 264 [&](const Fortran::parser::OpenMPSectionConstruct §ionConstruct) { 265 TODO(converter.getCurrentLocation(), "OpenMPSectionConstruct"); 266 }, 267 [&](const Fortran::parser::OpenMPLoopConstruct &loopConstruct) { 268 TODO(converter.getCurrentLocation(), "OpenMPLoopConstruct"); 269 }, 270 [&](const Fortran::parser::OpenMPDeclarativeAllocate 271 &execAllocConstruct) { 272 TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate"); 273 }, 274 [&](const Fortran::parser::OpenMPExecutableAllocate 275 &execAllocConstruct) { 276 TODO(converter.getCurrentLocation(), "OpenMPExecutableAllocate"); 277 }, 278 [&](const Fortran::parser::OpenMPBlockConstruct &blockConstruct) { 279 genOMP(converter, eval, blockConstruct); 280 }, 281 [&](const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) { 282 TODO(converter.getCurrentLocation(), "OpenMPAtomicConstruct"); 283 }, 284 [&](const Fortran::parser::OpenMPCriticalConstruct 285 &criticalConstruct) { 286 genOMP(converter, eval, criticalConstruct); 287 }, 288 }, 289 ompConstruct.u); 290 } 291