1 //===-- OpenMP.cpp -- Open MP directive lowering --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Lower/OpenMP.h" 14 #include "flang/Common/idioms.h" 15 #include "flang/Lower/Bridge.h" 16 #include "flang/Lower/ConvertExpr.h" 17 #include "flang/Lower/PFTBuilder.h" 18 #include "flang/Lower/StatementContext.h" 19 #include "flang/Optimizer/Builder/BoxValue.h" 20 #include "flang/Optimizer/Builder/FIRBuilder.h" 21 #include "flang/Optimizer/Builder/Todo.h" 22 #include "flang/Parser/parse-tree.h" 23 #include "flang/Semantics/tools.h" 24 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 25 #include "llvm/Frontend/OpenMP/OMPConstants.h" 26 27 using namespace mlir; 28 29 int64_t Fortran::lower::getCollapseValue( 30 const Fortran::parser::OmpClauseList &clauseList) { 31 for (const auto &clause : clauseList.v) { 32 if (const auto &collapseClause = 33 std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u)) { 34 const auto *expr = Fortran::semantics::GetExpr(collapseClause->v); 35 return Fortran::evaluate::ToInt64(*expr).value(); 36 } 37 } 38 return 1; 39 } 40 41 static const Fortran::parser::Name * 42 getDesignatorNameIfDataRef(const Fortran::parser::Designator &designator) { 43 const auto *dataRef = std::get_if<Fortran::parser::DataRef>(&designator.u); 44 return dataRef ? std::get_if<Fortran::parser::Name>(&dataRef->u) : nullptr; 45 } 46 47 static Fortran::semantics::Symbol * 48 getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) { 49 Fortran::semantics::Symbol *sym = nullptr; 50 std::visit(Fortran::common::visitors{ 51 [&](const Fortran::parser::Designator &designator) { 52 if (const Fortran::parser::Name *name = 53 getDesignatorNameIfDataRef(designator)) { 54 sym = name->symbol; 55 } 56 }, 57 [&](const Fortran::parser::Name &name) { sym = name.symbol; }}, 58 ompObject.u); 59 return sym; 60 } 61 62 template <typename T> 63 static void createPrivateVarSyms(Fortran::lower::AbstractConverter &converter, 64 const T *clause) { 65 const Fortran::parser::OmpObjectList &ompObjectList = clause->v; 66 for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) { 67 Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); 68 // Privatization for symbols which are pre-determined (like loop index 69 // variables) happen separately, for everything else privatize here. 70 if (sym->test(Fortran::semantics::Symbol::Flag::OmpPreDetermined)) 71 continue; 72 bool success = converter.createHostAssociateVarClone(*sym); 73 (void)success; 74 assert(success && "Privatization failed due to existing binding"); 75 if constexpr (std::is_same_v<T, Fortran::parser::OmpClause::Firstprivate>) { 76 converter.copyHostAssociateVar(*sym); 77 } 78 } 79 } 80 81 static void privatizeVars(Fortran::lower::AbstractConverter &converter, 82 const Fortran::parser::OmpClauseList &opClauseList) { 83 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 84 auto insPt = firOpBuilder.saveInsertionPoint(); 85 firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock()); 86 bool hasFirstPrivateOp = false; 87 for (const Fortran::parser::OmpClause &clause : opClauseList.v) { 88 if (const auto &privateClause = 89 std::get_if<Fortran::parser::OmpClause::Private>(&clause.u)) { 90 createPrivateVarSyms(converter, privateClause); 91 } else if (const auto &firstPrivateClause = 92 std::get_if<Fortran::parser::OmpClause::Firstprivate>( 93 &clause.u)) { 94 createPrivateVarSyms(converter, firstPrivateClause); 95 hasFirstPrivateOp = true; 96 } 97 } 98 if (hasFirstPrivateOp) 99 firOpBuilder.create<mlir::omp::BarrierOp>(converter.getCurrentLocation()); 100 firOpBuilder.restoreInsertionPoint(insPt); 101 } 102 103 /// The COMMON block is a global structure. \p commonValue is the base address 104 /// of the the COMMON block. As the offset from the symbol \p sym, generate the 105 /// COMMON block member value (commonValue + offset) for the symbol. 106 /// FIXME: Share the code with `instantiateCommon` in ConvertVariable.cpp. 107 static mlir::Value 108 genCommonBlockMember(Fortran::lower::AbstractConverter &converter, 109 const Fortran::semantics::Symbol &sym, 110 mlir::Value commonValue) { 111 auto &firOpBuilder = converter.getFirOpBuilder(); 112 mlir::Location currentLocation = converter.getCurrentLocation(); 113 mlir::IntegerType i8Ty = firOpBuilder.getIntegerType(8); 114 mlir::Type i8Ptr = firOpBuilder.getRefType(i8Ty); 115 mlir::Type seqTy = firOpBuilder.getRefType(firOpBuilder.getVarLenSeqTy(i8Ty)); 116 mlir::Value base = 117 firOpBuilder.createConvert(currentLocation, seqTy, commonValue); 118 std::size_t byteOffset = sym.GetUltimate().offset(); 119 mlir::Value offs = firOpBuilder.createIntegerConstant( 120 currentLocation, firOpBuilder.getIndexType(), byteOffset); 121 mlir::Value varAddr = firOpBuilder.create<fir::CoordinateOp>( 122 currentLocation, i8Ptr, base, mlir::ValueRange{offs}); 123 mlir::Type symType = converter.genType(sym); 124 return firOpBuilder.createConvert(currentLocation, 125 firOpBuilder.getRefType(symType), varAddr); 126 } 127 128 // Get the extended value for \p val by extracting additional variable 129 // information from \p base. 130 static fir::ExtendedValue getExtendedValue(fir::ExtendedValue base, 131 mlir::Value val) { 132 return base.match( 133 [&](const fir::MutableBoxValue &box) -> fir::ExtendedValue { 134 return fir::MutableBoxValue(val, box.nonDeferredLenParams(), {}); 135 }, 136 [&](const auto &) -> fir::ExtendedValue { 137 return fir::substBase(base, val); 138 }); 139 } 140 141 static void threadPrivatizeVars(Fortran::lower::AbstractConverter &converter, 142 Fortran::lower::pft::Evaluation &eval) { 143 auto &firOpBuilder = converter.getFirOpBuilder(); 144 mlir::Location currentLocation = converter.getCurrentLocation(); 145 auto insPt = firOpBuilder.saveInsertionPoint(); 146 firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock()); 147 148 // Get the original ThreadprivateOp corresponding to the symbol and use the 149 // symbol value from that opeartion to create one ThreadprivateOp copy 150 // operation inside the parallel region. 151 auto genThreadprivateOp = [&](Fortran::lower::SymbolRef sym) -> mlir::Value { 152 mlir::Value symOriThreadprivateValue = converter.getSymbolAddress(sym); 153 mlir::Operation *op = symOriThreadprivateValue.getDefiningOp(); 154 assert(mlir::isa<mlir::omp::ThreadprivateOp>(op) && 155 "The threadprivate operation not created"); 156 mlir::Value symValue = 157 mlir::dyn_cast<mlir::omp::ThreadprivateOp>(op).sym_addr(); 158 return firOpBuilder.create<mlir::omp::ThreadprivateOp>( 159 currentLocation, symValue.getType(), symValue); 160 }; 161 162 llvm::SetVector<const Fortran::semantics::Symbol *> threadprivateSyms; 163 converter.collectSymbolSet(eval, threadprivateSyms, 164 Fortran::semantics::Symbol::Flag::OmpThreadprivate, 165 /*isUltimateSymbol=*/false); 166 std::set<Fortran::semantics::SourceName> threadprivateSymNames; 167 168 // For a COMMON block, the ThreadprivateOp is generated for itself instead of 169 // its members, so only bind the value of the new copied ThreadprivateOp 170 // inside the parallel region to the common block symbol only once for 171 // multiple members in one COMMON block. 172 llvm::SetVector<const Fortran::semantics::Symbol *> commonSyms; 173 for (std::size_t i = 0; i < threadprivateSyms.size(); i++) { 174 auto sym = threadprivateSyms[i]; 175 mlir::Value symThreadprivateValue; 176 // The variable may be used more than once, and each reference has one 177 // symbol with the same name. Only do once for references of one variable. 178 if (threadprivateSymNames.find(sym->name()) != threadprivateSymNames.end()) 179 continue; 180 threadprivateSymNames.insert(sym->name()); 181 if (const Fortran::semantics::Symbol *common = 182 Fortran::semantics::FindCommonBlockContaining(sym->GetUltimate())) { 183 mlir::Value commonThreadprivateValue; 184 if (commonSyms.contains(common)) { 185 commonThreadprivateValue = converter.getSymbolAddress(*common); 186 } else { 187 commonThreadprivateValue = genThreadprivateOp(*common); 188 converter.bindSymbol(*common, commonThreadprivateValue); 189 commonSyms.insert(common); 190 } 191 symThreadprivateValue = 192 genCommonBlockMember(converter, *sym, commonThreadprivateValue); 193 } else { 194 symThreadprivateValue = genThreadprivateOp(*sym); 195 } 196 197 fir::ExtendedValue sexv = converter.getSymbolExtendedValue(*sym); 198 fir::ExtendedValue symThreadprivateExv = 199 getExtendedValue(sexv, symThreadprivateValue); 200 converter.bindSymbol(*sym, symThreadprivateExv); 201 } 202 203 firOpBuilder.restoreInsertionPoint(insPt); 204 } 205 206 static void 207 genCopyinClause(Fortran::lower::AbstractConverter &converter, 208 const Fortran::parser::OmpClauseList &opClauseList) { 209 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 210 mlir::OpBuilder::InsertPoint insPt = firOpBuilder.saveInsertionPoint(); 211 firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock()); 212 bool hasCopyin = false; 213 for (const Fortran::parser::OmpClause &clause : opClauseList.v) { 214 if (const auto ©inClause = 215 std::get_if<Fortran::parser::OmpClause::Copyin>(&clause.u)) { 216 hasCopyin = true; 217 const Fortran::parser::OmpObjectList &ompObjectList = copyinClause->v; 218 for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) { 219 Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); 220 if (sym->has<Fortran::semantics::CommonBlockDetails>()) 221 TODO(converter.getCurrentLocation(), "common block in Copyin clause"); 222 if (Fortran::semantics::IsAllocatableOrPointer(sym->GetUltimate())) 223 TODO(converter.getCurrentLocation(), 224 "pointer or allocatable variables in Copyin clause"); 225 assert(sym->has<Fortran::semantics::HostAssocDetails>() && 226 "No host-association found"); 227 converter.copyHostAssociateVar(*sym); 228 } 229 } 230 } 231 // [OMP 5.0, 2.19.6.1] The copy is done after the team is formed and prior to 232 // the execution of the associated structured block. Emit implicit barrier to 233 // synchronize threads and avoid data races on propagation master's thread 234 // values of threadprivate variables to local instances of that variables of 235 // all other implicit threads. 236 if (hasCopyin) 237 firOpBuilder.create<mlir::omp::BarrierOp>(converter.getCurrentLocation()); 238 firOpBuilder.restoreInsertionPoint(insPt); 239 } 240 241 static void genObjectList(const Fortran::parser::OmpObjectList &objectList, 242 Fortran::lower::AbstractConverter &converter, 243 llvm::SmallVectorImpl<Value> &operands) { 244 auto addOperands = [&](Fortran::lower::SymbolRef sym) { 245 const mlir::Value variable = converter.getSymbolAddress(sym); 246 if (variable) { 247 operands.push_back(variable); 248 } else { 249 if (const auto *details = 250 sym->detailsIf<Fortran::semantics::HostAssocDetails>()) { 251 operands.push_back(converter.getSymbolAddress(details->symbol())); 252 converter.copySymbolBinding(details->symbol(), sym); 253 } 254 } 255 }; 256 for (const Fortran::parser::OmpObject &ompObject : objectList.v) { 257 Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); 258 addOperands(*sym); 259 } 260 } 261 262 static mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter, 263 std::size_t loopVarTypeSize) { 264 // OpenMP runtime requires 32-bit or 64-bit loop variables. 265 loopVarTypeSize = loopVarTypeSize * 8; 266 if (loopVarTypeSize < 32) { 267 loopVarTypeSize = 32; 268 } else if (loopVarTypeSize > 64) { 269 loopVarTypeSize = 64; 270 mlir::emitWarning(converter.getCurrentLocation(), 271 "OpenMP loop iteration variable cannot have more than 64 " 272 "bits size and will be narrowed into 64 bits."); 273 } 274 assert((loopVarTypeSize == 32 || loopVarTypeSize == 64) && 275 "OpenMP loop iteration variable size must be transformed into 32-bit " 276 "or 64-bit"); 277 return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize); 278 } 279 280 /// Create empty blocks for the current region. 281 /// These blocks replace blocks parented to an enclosing region. 282 void createEmptyRegionBlocks( 283 fir::FirOpBuilder &firOpBuilder, 284 std::list<Fortran::lower::pft::Evaluation> &evaluationList) { 285 auto *region = &firOpBuilder.getRegion(); 286 for (auto &eval : evaluationList) { 287 if (eval.block) { 288 if (eval.block->empty()) { 289 eval.block->erase(); 290 eval.block = firOpBuilder.createBlock(region); 291 } else { 292 [[maybe_unused]] auto &terminatorOp = eval.block->back(); 293 assert((mlir::isa<mlir::omp::TerminatorOp>(terminatorOp) || 294 mlir::isa<mlir::omp::YieldOp>(terminatorOp)) && 295 "expected terminator op"); 296 } 297 } 298 if (!eval.isDirective() && eval.hasNestedEvaluations()) 299 createEmptyRegionBlocks(firOpBuilder, eval.getNestedEvaluations()); 300 } 301 } 302 303 /// Create the body (block) for an OpenMP Operation. 304 /// 305 /// \param [in] op - the operation the body belongs to. 306 /// \param [inout] converter - converter to use for the clauses. 307 /// \param [in] loc - location in source code. 308 /// \param [in] eval - current PFT node/evaluation. 309 /// \oaran [in] clauses - list of clauses to process. 310 /// \param [in] args - block arguments (induction variable[s]) for the 311 //// region. 312 /// \param [in] outerCombined - is this an outer operation - prevents 313 /// privatization. 314 template <typename Op> 315 static void 316 createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter, 317 mlir::Location &loc, Fortran::lower::pft::Evaluation &eval, 318 const Fortran::parser::OmpClauseList *clauses = nullptr, 319 const SmallVector<const Fortran::semantics::Symbol *> &args = {}, 320 bool outerCombined = false) { 321 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 322 // If an argument for the region is provided then create the block with that 323 // argument. Also update the symbol's address with the mlir argument value. 324 // e.g. For loops the argument is the induction variable. And all further 325 // uses of the induction variable should use this mlir value. 326 mlir::Operation *storeOp = nullptr; 327 if (args.size()) { 328 std::size_t loopVarTypeSize = 0; 329 for (const Fortran::semantics::Symbol *arg : args) 330 loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size()); 331 mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); 332 SmallVector<Type> tiv; 333 SmallVector<Location> locs; 334 for (int i = 0; i < (int)args.size(); i++) { 335 tiv.push_back(loopVarType); 336 locs.push_back(loc); 337 } 338 firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs); 339 int argIndex = 0; 340 // The argument is not currently in memory, so make a temporary for the 341 // argument, and store it there, then bind that location to the argument. 342 for (const Fortran::semantics::Symbol *arg : args) { 343 mlir::Value val = 344 fir::getBase(op.getRegion().front().getArgument(argIndex)); 345 mlir::Value temp = firOpBuilder.createTemporary( 346 loc, loopVarType, 347 llvm::ArrayRef<mlir::NamedAttribute>{ 348 Fortran::lower::getAdaptToByRefAttr(firOpBuilder)}); 349 storeOp = firOpBuilder.create<fir::StoreOp>(loc, val, temp); 350 converter.bindSymbol(*arg, temp); 351 argIndex++; 352 } 353 } else { 354 firOpBuilder.createBlock(&op.getRegion()); 355 } 356 // Set the insert for the terminator operation to go at the end of the 357 // block - this is either empty or the block with the stores above, 358 // the end of the block works for both. 359 mlir::Block &block = op.getRegion().back(); 360 firOpBuilder.setInsertionPointToEnd(&block); 361 362 // If it is an unstructured region and is not the outer region of a combined 363 // construct, create empty blocks for all evaluations. 364 if (eval.lowerAsUnstructured() && !outerCombined) 365 createEmptyRegionBlocks(firOpBuilder, eval.getNestedEvaluations()); 366 367 // Insert the terminator. 368 if constexpr (std::is_same_v<Op, omp::WsLoopOp> || 369 std::is_same_v<Op, omp::SimdLoopOp>) { 370 mlir::ValueRange results; 371 firOpBuilder.create<mlir::omp::YieldOp>(loc, results); 372 } else { 373 firOpBuilder.create<mlir::omp::TerminatorOp>(loc); 374 } 375 376 // Reset the insert point to before the terminator. 377 if (storeOp) 378 firOpBuilder.setInsertionPointAfter(storeOp); 379 else 380 firOpBuilder.setInsertionPointToStart(&block); 381 382 // Handle privatization. Do not privatize if this is the outer operation. 383 if (clauses && !outerCombined) 384 privatizeVars(converter, *clauses); 385 386 if constexpr (std::is_same_v<Op, omp::ParallelOp>) { 387 threadPrivatizeVars(converter, eval); 388 if (clauses) 389 genCopyinClause(converter, *clauses); 390 } 391 } 392 393 static void genOMP(Fortran::lower::AbstractConverter &converter, 394 Fortran::lower::pft::Evaluation &eval, 395 const Fortran::parser::OpenMPSimpleStandaloneConstruct 396 &simpleStandaloneConstruct) { 397 const auto &directive = 398 std::get<Fortran::parser::OmpSimpleStandaloneDirective>( 399 simpleStandaloneConstruct.t); 400 switch (directive.v) { 401 default: 402 break; 403 case llvm::omp::Directive::OMPD_barrier: 404 converter.getFirOpBuilder().create<mlir::omp::BarrierOp>( 405 converter.getCurrentLocation()); 406 break; 407 case llvm::omp::Directive::OMPD_taskwait: 408 converter.getFirOpBuilder().create<mlir::omp::TaskwaitOp>( 409 converter.getCurrentLocation()); 410 break; 411 case llvm::omp::Directive::OMPD_taskyield: 412 converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>( 413 converter.getCurrentLocation()); 414 break; 415 case llvm::omp::Directive::OMPD_target_enter_data: 416 TODO(converter.getCurrentLocation(), "OMPD_target_enter_data"); 417 case llvm::omp::Directive::OMPD_target_exit_data: 418 TODO(converter.getCurrentLocation(), "OMPD_target_exit_data"); 419 case llvm::omp::Directive::OMPD_target_update: 420 TODO(converter.getCurrentLocation(), "OMPD_target_update"); 421 case llvm::omp::Directive::OMPD_ordered: 422 TODO(converter.getCurrentLocation(), "OMPD_ordered"); 423 } 424 } 425 426 static void 427 genAllocateClause(Fortran::lower::AbstractConverter &converter, 428 const Fortran::parser::OmpAllocateClause &ompAllocateClause, 429 SmallVector<Value> &allocatorOperands, 430 SmallVector<Value> &allocateOperands) { 431 auto &firOpBuilder = converter.getFirOpBuilder(); 432 auto currentLocation = converter.getCurrentLocation(); 433 Fortran::lower::StatementContext stmtCtx; 434 435 mlir::Value allocatorOperand; 436 const Fortran::parser::OmpObjectList &ompObjectList = 437 std::get<Fortran::parser::OmpObjectList>(ompAllocateClause.t); 438 const auto &allocatorValue = 439 std::get<std::optional<Fortran::parser::OmpAllocateClause::Allocator>>( 440 ompAllocateClause.t); 441 // Check if allocate clause has allocator specified. If so, add it 442 // to list of allocators, otherwise, add default allocator to 443 // list of allocators. 444 if (allocatorValue) { 445 allocatorOperand = fir::getBase(converter.genExprValue( 446 *Fortran::semantics::GetExpr(allocatorValue->v), stmtCtx)); 447 allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), 448 allocatorOperand); 449 } else { 450 allocatorOperand = firOpBuilder.createIntegerConstant( 451 currentLocation, firOpBuilder.getI32Type(), 1); 452 allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), 453 allocatorOperand); 454 } 455 genObjectList(ompObjectList, converter, allocateOperands); 456 } 457 458 static void 459 genOMP(Fortran::lower::AbstractConverter &converter, 460 Fortran::lower::pft::Evaluation &eval, 461 const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) { 462 std::visit( 463 Fortran::common::visitors{ 464 [&](const Fortran::parser::OpenMPSimpleStandaloneConstruct 465 &simpleStandaloneConstruct) { 466 genOMP(converter, eval, simpleStandaloneConstruct); 467 }, 468 [&](const Fortran::parser::OpenMPFlushConstruct &flushConstruct) { 469 SmallVector<Value, 4> operandRange; 470 if (const auto &ompObjectList = 471 std::get<std::optional<Fortran::parser::OmpObjectList>>( 472 flushConstruct.t)) 473 genObjectList(*ompObjectList, converter, operandRange); 474 const auto &memOrderClause = std::get<std::optional< 475 std::list<Fortran::parser::OmpMemoryOrderClause>>>( 476 flushConstruct.t); 477 if (memOrderClause.has_value() && memOrderClause->size() > 0) 478 TODO(converter.getCurrentLocation(), 479 "Handle OmpMemoryOrderClause"); 480 converter.getFirOpBuilder().create<mlir::omp::FlushOp>( 481 converter.getCurrentLocation(), operandRange); 482 }, 483 [&](const Fortran::parser::OpenMPCancelConstruct &cancelConstruct) { 484 TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct"); 485 }, 486 [&](const Fortran::parser::OpenMPCancellationPointConstruct 487 &cancellationPointConstruct) { 488 TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct"); 489 }, 490 }, 491 standaloneConstruct.u); 492 } 493 494 static omp::ClauseProcBindKindAttr genProcBindKindAttr( 495 fir::FirOpBuilder &firOpBuilder, 496 const Fortran::parser::OmpClause::ProcBind *procBindClause) { 497 omp::ClauseProcBindKind pbKind; 498 switch (procBindClause->v.v) { 499 case Fortran::parser::OmpProcBindClause::Type::Master: 500 pbKind = omp::ClauseProcBindKind::Master; 501 break; 502 case Fortran::parser::OmpProcBindClause::Type::Close: 503 pbKind = omp::ClauseProcBindKind::Close; 504 break; 505 case Fortran::parser::OmpProcBindClause::Type::Spread: 506 pbKind = omp::ClauseProcBindKind::Spread; 507 break; 508 case Fortran::parser::OmpProcBindClause::Type::Primary: 509 pbKind = omp::ClauseProcBindKind::Primary; 510 break; 511 } 512 return omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(), pbKind); 513 } 514 515 static mlir::Value 516 getIfClauseOperand(Fortran::lower::AbstractConverter &converter, 517 Fortran::lower::StatementContext &stmtCtx, 518 const Fortran::parser::OmpClause::If *ifClause) { 519 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 520 mlir::Location currentLocation = converter.getCurrentLocation(); 521 auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t); 522 mlir::Value ifVal = fir::getBase( 523 converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)); 524 return firOpBuilder.createConvert(currentLocation, firOpBuilder.getI1Type(), 525 ifVal); 526 } 527 528 /* When parallel is used in a combined construct, then use this function to 529 * create the parallel operation. It handles the parallel specific clauses 530 * and leaves the rest for handling at the inner operations. 531 * TODO: Refactor clause handling 532 */ 533 template <typename Directive> 534 static void 535 createCombinedParallelOp(Fortran::lower::AbstractConverter &converter, 536 Fortran::lower::pft::Evaluation &eval, 537 const Directive &directive) { 538 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 539 mlir::Location currentLocation = converter.getCurrentLocation(); 540 Fortran::lower::StatementContext stmtCtx; 541 llvm::ArrayRef<mlir::Type> argTy; 542 mlir::Value ifClauseOperand, numThreadsClauseOperand; 543 SmallVector<Value> allocatorOperands, allocateOperands; 544 mlir::omp::ClauseProcBindKindAttr procBindKindAttr; 545 const auto &opClauseList = 546 std::get<Fortran::parser::OmpClauseList>(directive.t); 547 // TODO: Handle the following clauses 548 // 1. default 549 // Note: rest of the clauses are handled when the inner operation is created 550 for (const Fortran::parser::OmpClause &clause : opClauseList.v) { 551 if (const auto &ifClause = 552 std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) { 553 ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause); 554 } else if (const auto &numThreadsClause = 555 std::get_if<Fortran::parser::OmpClause::NumThreads>( 556 &clause.u)) { 557 numThreadsClauseOperand = fir::getBase(converter.genExprValue( 558 *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx)); 559 } else if (const auto &procBindClause = 560 std::get_if<Fortran::parser::OmpClause::ProcBind>( 561 &clause.u)) { 562 procBindKindAttr = genProcBindKindAttr(firOpBuilder, procBindClause); 563 } 564 } 565 // Create and insert the operation. 566 auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>( 567 currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand, 568 allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(), 569 /*reductions=*/nullptr, procBindKindAttr); 570 571 createBodyOfOp<omp::ParallelOp>(parallelOp, converter, currentLocation, eval, 572 &opClauseList, /*iv=*/{}, 573 /*isCombined=*/true); 574 } 575 576 static void 577 genOMP(Fortran::lower::AbstractConverter &converter, 578 Fortran::lower::pft::Evaluation &eval, 579 const Fortran::parser::OpenMPBlockConstruct &blockConstruct) { 580 const auto &beginBlockDirective = 581 std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t); 582 const auto &blockDirective = 583 std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t); 584 const auto &endBlockDirective = 585 std::get<Fortran::parser::OmpEndBlockDirective>(blockConstruct.t); 586 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 587 mlir::Location currentLocation = converter.getCurrentLocation(); 588 589 Fortran::lower::StatementContext stmtCtx; 590 llvm::ArrayRef<mlir::Type> argTy; 591 mlir::Value ifClauseOperand, numThreadsClauseOperand, finalClauseOperand, 592 priorityClauseOperand; 593 mlir::omp::ClauseProcBindKindAttr procBindKindAttr; 594 SmallVector<Value> allocateOperands, allocatorOperands; 595 mlir::UnitAttr nowaitAttr, untiedAttr, mergeableAttr; 596 597 const auto &opClauseList = 598 std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t); 599 for (const auto &clause : opClauseList.v) { 600 if (const auto &ifClause = 601 std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) { 602 ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause); 603 } else if (const auto &numThreadsClause = 604 std::get_if<Fortran::parser::OmpClause::NumThreads>( 605 &clause.u)) { 606 // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`. 607 numThreadsClauseOperand = fir::getBase(converter.genExprValue( 608 *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx)); 609 } else if (const auto &procBindClause = 610 std::get_if<Fortran::parser::OmpClause::ProcBind>( 611 &clause.u)) { 612 procBindKindAttr = genProcBindKindAttr(firOpBuilder, procBindClause); 613 } else if (const auto &allocateClause = 614 std::get_if<Fortran::parser::OmpClause::Allocate>( 615 &clause.u)) { 616 genAllocateClause(converter, allocateClause->v, allocatorOperands, 617 allocateOperands); 618 } else if (std::get_if<Fortran::parser::OmpClause::Private>(&clause.u) || 619 std::get_if<Fortran::parser::OmpClause::Firstprivate>( 620 &clause.u) || 621 std::get_if<Fortran::parser::OmpClause::Copyin>(&clause.u)) { 622 // Privatisation and copyin clauses are handled elsewhere. 623 continue; 624 } else if (std::get_if<Fortran::parser::OmpClause::Shared>(&clause.u)) { 625 // Shared is the default behavior in the IR, so no handling is required. 626 continue; 627 } else if (const auto &defaultClause = 628 std::get_if<Fortran::parser::OmpClause::Default>( 629 &clause.u)) { 630 if ((defaultClause->v.v == 631 Fortran::parser::OmpDefaultClause::Type::Shared) || 632 (defaultClause->v.v == 633 Fortran::parser::OmpDefaultClause::Type::None)) { 634 // Default clause with shared or none do not require any handling since 635 // Shared is the default behavior in the IR and None is only required 636 // for semantic checks. 637 continue; 638 } 639 } else if (std::get_if<Fortran::parser::OmpClause::Threads>(&clause.u)) { 640 // Nothing needs to be done for threads clause. 641 continue; 642 } else if (const auto &finalClause = 643 std::get_if<Fortran::parser::OmpClause::Final>(&clause.u)) { 644 mlir::Value finalVal = fir::getBase(converter.genExprValue( 645 *Fortran::semantics::GetExpr(finalClause->v), stmtCtx)); 646 finalClauseOperand = firOpBuilder.createConvert( 647 currentLocation, firOpBuilder.getI1Type(), finalVal); 648 } else if (std::get_if<Fortran::parser::OmpClause::Untied>(&clause.u)) { 649 untiedAttr = firOpBuilder.getUnitAttr(); 650 } else if (std::get_if<Fortran::parser::OmpClause::Mergeable>(&clause.u)) { 651 mergeableAttr = firOpBuilder.getUnitAttr(); 652 } else if (const auto &priorityClause = 653 std::get_if<Fortran::parser::OmpClause::Priority>( 654 &clause.u)) { 655 priorityClauseOperand = fir::getBase(converter.genExprValue( 656 *Fortran::semantics::GetExpr(priorityClause->v), stmtCtx)); 657 } else { 658 TODO(currentLocation, "OpenMP Block construct clauses"); 659 } 660 } 661 662 for (const auto &clause : 663 std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t).v) { 664 if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u)) 665 nowaitAttr = firOpBuilder.getUnitAttr(); 666 } 667 668 if (blockDirective.v == llvm::omp::OMPD_parallel) { 669 // Create and insert the operation. 670 auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>( 671 currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand, 672 allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(), 673 /*reductions=*/nullptr, procBindKindAttr); 674 createBodyOfOp<omp::ParallelOp>(parallelOp, converter, currentLocation, 675 eval, &opClauseList); 676 } else if (blockDirective.v == llvm::omp::OMPD_master) { 677 auto masterOp = 678 firOpBuilder.create<mlir::omp::MasterOp>(currentLocation, argTy); 679 createBodyOfOp<omp::MasterOp>(masterOp, converter, currentLocation, eval); 680 } else if (blockDirective.v == llvm::omp::OMPD_single) { 681 auto singleOp = firOpBuilder.create<mlir::omp::SingleOp>( 682 currentLocation, allocateOperands, allocatorOperands, nowaitAttr); 683 createBodyOfOp<omp::SingleOp>(singleOp, converter, currentLocation, eval); 684 } else if (blockDirective.v == llvm::omp::OMPD_ordered) { 685 auto orderedOp = firOpBuilder.create<mlir::omp::OrderedRegionOp>( 686 currentLocation, /*simd=*/nullptr); 687 createBodyOfOp<omp::OrderedRegionOp>(orderedOp, converter, currentLocation, 688 eval); 689 } else if (blockDirective.v == llvm::omp::OMPD_task) { 690 auto taskOp = firOpBuilder.create<mlir::omp::TaskOp>( 691 currentLocation, ifClauseOperand, finalClauseOperand, untiedAttr, 692 mergeableAttr, /*in_reduction_vars=*/ValueRange(), 693 /*in_reductions=*/nullptr, priorityClauseOperand, allocateOperands, 694 allocatorOperands); 695 createBodyOfOp(taskOp, converter, currentLocation, eval, &opClauseList); 696 } else { 697 TODO(converter.getCurrentLocation(), "Unhandled block directive"); 698 } 699 } 700 701 /// Creates an OpenMP reduction declaration and inserts it into the provided 702 /// symbol table. The declaration has a constant initializer with the neutral 703 /// value `initValue`, and the reduction combiner carried over from `reduce`. 704 /// TODO: Generalize this for non-integer types, add atomic region. 705 static omp::ReductionDeclareOp createReductionDecl(fir::FirOpBuilder &builder, 706 llvm::StringRef name, 707 mlir::Type type, 708 mlir::Location loc) { 709 OpBuilder::InsertionGuard guard(builder); 710 mlir::ModuleOp module = builder.getModule(); 711 mlir::OpBuilder modBuilder(module.getBodyRegion()); 712 auto decl = module.lookupSymbol<mlir::omp::ReductionDeclareOp>(name); 713 if (!decl) 714 decl = modBuilder.create<omp::ReductionDeclareOp>(loc, name, type); 715 else 716 return decl; 717 718 builder.createBlock(&decl.initializerRegion(), decl.initializerRegion().end(), 719 {type}, {loc}); 720 builder.setInsertionPointToEnd(&decl.initializerRegion().back()); 721 Value init = builder.create<mlir::arith::ConstantOp>( 722 loc, type, builder.getIntegerAttr(type, 0)); 723 builder.create<omp::YieldOp>(loc, init); 724 725 builder.createBlock(&decl.reductionRegion(), decl.reductionRegion().end(), 726 {type, type}, {loc, loc}); 727 builder.setInsertionPointToEnd(&decl.reductionRegion().back()); 728 mlir::Value op1 = decl.reductionRegion().front().getArgument(0); 729 mlir::Value op2 = decl.reductionRegion().front().getArgument(1); 730 Value addRes = builder.create<mlir::arith::AddIOp>(loc, op1, op2); 731 builder.create<omp::YieldOp>(loc, addRes); 732 return decl; 733 } 734 735 static mlir::omp::ScheduleModifier 736 translateModifier(const Fortran::parser::OmpScheduleModifierType &m) { 737 switch (m.v) { 738 case Fortran::parser::OmpScheduleModifierType::ModType::Monotonic: 739 return mlir::omp::ScheduleModifier::monotonic; 740 case Fortran::parser::OmpScheduleModifierType::ModType::Nonmonotonic: 741 return mlir::omp::ScheduleModifier::nonmonotonic; 742 case Fortran::parser::OmpScheduleModifierType::ModType::Simd: 743 return mlir::omp::ScheduleModifier::simd; 744 } 745 return mlir::omp::ScheduleModifier::none; 746 } 747 748 static mlir::omp::ScheduleModifier 749 getScheduleModifier(const Fortran::parser::OmpScheduleClause &x) { 750 const auto &modifier = 751 std::get<std::optional<Fortran::parser::OmpScheduleModifier>>(x.t); 752 // The input may have the modifier any order, so we look for one that isn't 753 // SIMD. If modifier is not set at all, fall down to the bottom and return 754 // "none". 755 if (modifier) { 756 const auto &modType1 = 757 std::get<Fortran::parser::OmpScheduleModifier::Modifier1>(modifier->t); 758 if (modType1.v.v == 759 Fortran::parser::OmpScheduleModifierType::ModType::Simd) { 760 const auto &modType2 = std::get< 761 std::optional<Fortran::parser::OmpScheduleModifier::Modifier2>>( 762 modifier->t); 763 if (modType2 && 764 modType2->v.v != 765 Fortran::parser::OmpScheduleModifierType::ModType::Simd) 766 return translateModifier(modType2->v); 767 768 return mlir::omp::ScheduleModifier::none; 769 } 770 771 return translateModifier(modType1.v); 772 } 773 return mlir::omp::ScheduleModifier::none; 774 } 775 776 static mlir::omp::ScheduleModifier 777 getSIMDModifier(const Fortran::parser::OmpScheduleClause &x) { 778 const auto &modifier = 779 std::get<std::optional<Fortran::parser::OmpScheduleModifier>>(x.t); 780 // Either of the two possible modifiers in the input can be the SIMD modifier, 781 // so look in either one, and return simd if we find one. Not found = return 782 // "none". 783 if (modifier) { 784 const auto &modType1 = 785 std::get<Fortran::parser::OmpScheduleModifier::Modifier1>(modifier->t); 786 if (modType1.v.v == Fortran::parser::OmpScheduleModifierType::ModType::Simd) 787 return mlir::omp::ScheduleModifier::simd; 788 789 const auto &modType2 = std::get< 790 std::optional<Fortran::parser::OmpScheduleModifier::Modifier2>>( 791 modifier->t); 792 if (modType2 && modType2->v.v == 793 Fortran::parser::OmpScheduleModifierType::ModType::Simd) 794 return mlir::omp::ScheduleModifier::simd; 795 } 796 return mlir::omp::ScheduleModifier::none; 797 } 798 799 static std::string getReductionName( 800 Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, 801 mlir::Type ty) { 802 std::string reductionName; 803 if (intrinsicOp == Fortran::parser::DefinedOperator::IntrinsicOperator::Add) 804 reductionName = "add_reduction"; 805 else 806 reductionName = "other_reduction"; 807 808 return (llvm::Twine(reductionName) + 809 (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) + 810 llvm::Twine(ty.getIntOrFloatBitWidth())) 811 .str(); 812 } 813 814 static void genOMP(Fortran::lower::AbstractConverter &converter, 815 Fortran::lower::pft::Evaluation &eval, 816 const Fortran::parser::OpenMPLoopConstruct &loopConstruct) { 817 818 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 819 mlir::Location currentLocation = converter.getCurrentLocation(); 820 llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, linearVars, 821 linearStepVars, reductionVars; 822 mlir::Value scheduleChunkClauseOperand, ifClauseOperand; 823 mlir::Attribute scheduleClauseOperand, noWaitClauseOperand, 824 orderedClauseOperand, orderClauseOperand; 825 SmallVector<Attribute> reductionDeclSymbols; 826 Fortran::lower::StatementContext stmtCtx; 827 const auto &loopOpClauseList = std::get<Fortran::parser::OmpClauseList>( 828 std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t).t); 829 830 const auto ompDirective = 831 std::get<Fortran::parser::OmpLoopDirective>( 832 std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t).t) 833 .v; 834 if (llvm::omp::OMPD_parallel_do == ompDirective) { 835 createCombinedParallelOp<Fortran::parser::OmpBeginLoopDirective>( 836 converter, eval, 837 std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t)); 838 } else if (llvm::omp::OMPD_do != ompDirective && 839 llvm::omp::OMPD_simd != ompDirective) { 840 TODO(converter.getCurrentLocation(), "Construct enclosing do loop"); 841 } 842 843 // Collect the loops to collapse. 844 auto *doConstructEval = &eval.getFirstNestedEvaluation(); 845 846 std::int64_t collapseValue = 847 Fortran::lower::getCollapseValue(loopOpClauseList); 848 std::size_t loopVarTypeSize = 0; 849 SmallVector<const Fortran::semantics::Symbol *> iv; 850 do { 851 auto *doLoop = &doConstructEval->getFirstNestedEvaluation(); 852 auto *doStmt = doLoop->getIf<Fortran::parser::NonLabelDoStmt>(); 853 assert(doStmt && "Expected do loop to be in the nested evaluation"); 854 const auto &loopControl = 855 std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t); 856 const Fortran::parser::LoopControl::Bounds *bounds = 857 std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u); 858 assert(bounds && "Expected bounds for worksharing do loop"); 859 Fortran::lower::StatementContext stmtCtx; 860 lowerBound.push_back(fir::getBase(converter.genExprValue( 861 *Fortran::semantics::GetExpr(bounds->lower), stmtCtx))); 862 upperBound.push_back(fir::getBase(converter.genExprValue( 863 *Fortran::semantics::GetExpr(bounds->upper), stmtCtx))); 864 if (bounds->step) { 865 step.push_back(fir::getBase(converter.genExprValue( 866 *Fortran::semantics::GetExpr(bounds->step), stmtCtx))); 867 } else { // If `step` is not present, assume it as `1`. 868 step.push_back(firOpBuilder.createIntegerConstant( 869 currentLocation, firOpBuilder.getIntegerType(32), 1)); 870 } 871 iv.push_back(bounds->name.thing.symbol); 872 loopVarTypeSize = std::max(loopVarTypeSize, 873 bounds->name.thing.symbol->GetUltimate().size()); 874 875 collapseValue--; 876 doConstructEval = 877 &*std::next(doConstructEval->getNestedEvaluations().begin()); 878 } while (collapseValue > 0); 879 880 for (const auto &clause : loopOpClauseList.v) { 881 if (const auto &scheduleClause = 882 std::get_if<Fortran::parser::OmpClause::Schedule>(&clause.u)) { 883 if (const auto &chunkExpr = 884 std::get<std::optional<Fortran::parser::ScalarIntExpr>>( 885 scheduleClause->v.t)) { 886 if (const auto *expr = Fortran::semantics::GetExpr(*chunkExpr)) { 887 scheduleChunkClauseOperand = 888 fir::getBase(converter.genExprValue(*expr, stmtCtx)); 889 } 890 } 891 } else if (const auto &ifClause = 892 std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) { 893 ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause); 894 } else if (const auto &reductionClause = 895 std::get_if<Fortran::parser::OmpClause::Reduction>( 896 &clause.u)) { 897 omp::ReductionDeclareOp decl; 898 const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>( 899 reductionClause->v.t)}; 900 const auto &objectList{ 901 std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)}; 902 if (const auto &redDefinedOp = 903 std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) { 904 const auto &intrinsicOp{ 905 std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>( 906 redDefinedOp->u)}; 907 if (intrinsicOp != 908 Fortran::parser::DefinedOperator::IntrinsicOperator::Add) 909 TODO(currentLocation, 910 "Reduction of some intrinsic operators is not supported"); 911 for (const auto &ompObject : objectList.v) { 912 if (const auto *name{ 913 Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { 914 if (const auto *symbol{name->symbol}) { 915 mlir::Value symVal = converter.getSymbolAddress(*symbol); 916 mlir::Type redType = 917 symVal.getType().cast<fir::ReferenceType>().getEleTy(); 918 reductionVars.push_back(symVal); 919 if (redType.isIntOrIndex()) { 920 decl = createReductionDecl( 921 firOpBuilder, getReductionName(intrinsicOp, redType), 922 redType, currentLocation); 923 } else { 924 TODO(currentLocation, 925 "Reduction of some types is not supported"); 926 } 927 reductionDeclSymbols.push_back(SymbolRefAttr::get( 928 firOpBuilder.getContext(), decl.sym_name())); 929 } 930 } 931 } 932 } else { 933 TODO(currentLocation, 934 "Reduction of intrinsic procedures is not supported"); 935 } 936 } 937 } 938 939 // The types of lower bound, upper bound, and step are converted into the 940 // type of the loop variable if necessary. 941 mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); 942 for (unsigned it = 0; it < (unsigned)lowerBound.size(); it++) { 943 lowerBound[it] = firOpBuilder.createConvert(currentLocation, loopVarType, 944 lowerBound[it]); 945 upperBound[it] = firOpBuilder.createConvert(currentLocation, loopVarType, 946 upperBound[it]); 947 step[it] = 948 firOpBuilder.createConvert(currentLocation, loopVarType, step[it]); 949 } 950 951 // 2.9.3.1 SIMD construct 952 // TODO: Support all the clauses 953 if (llvm::omp::OMPD_simd == ompDirective) { 954 TypeRange resultType; 955 auto SimdLoopOp = firOpBuilder.create<mlir::omp::SimdLoopOp>( 956 currentLocation, resultType, lowerBound, upperBound, step, 957 ifClauseOperand, /*inclusive=*/firOpBuilder.getUnitAttr()); 958 createBodyOfOp<omp::SimdLoopOp>(SimdLoopOp, converter, currentLocation, 959 eval, &loopOpClauseList, iv); 960 return; 961 } 962 963 // FIXME: Add support for following clauses: 964 // 1. linear 965 // 2. order 966 auto wsLoopOp = firOpBuilder.create<mlir::omp::WsLoopOp>( 967 currentLocation, lowerBound, upperBound, step, linearVars, linearStepVars, 968 reductionVars, 969 reductionDeclSymbols.empty() 970 ? nullptr 971 : mlir::ArrayAttr::get(firOpBuilder.getContext(), 972 reductionDeclSymbols), 973 scheduleClauseOperand.dyn_cast_or_null<omp::ClauseScheduleKindAttr>(), 974 scheduleChunkClauseOperand, /*schedule_modifiers=*/nullptr, 975 /*simd_modifier=*/nullptr, 976 noWaitClauseOperand.dyn_cast_or_null<UnitAttr>(), 977 orderedClauseOperand.dyn_cast_or_null<IntegerAttr>(), 978 orderClauseOperand.dyn_cast_or_null<omp::ClauseOrderKindAttr>(), 979 /*inclusive=*/firOpBuilder.getUnitAttr()); 980 981 // Handle attribute based clauses. 982 for (const Fortran::parser::OmpClause &clause : loopOpClauseList.v) { 983 if (const auto &orderedClause = 984 std::get_if<Fortran::parser::OmpClause::Ordered>(&clause.u)) { 985 if (orderedClause->v.has_value()) { 986 const auto *expr = Fortran::semantics::GetExpr(orderedClause->v); 987 const std::optional<std::int64_t> orderedClauseValue = 988 Fortran::evaluate::ToInt64(*expr); 989 wsLoopOp.ordered_valAttr( 990 firOpBuilder.getI64IntegerAttr(*orderedClauseValue)); 991 } else { 992 wsLoopOp.ordered_valAttr(firOpBuilder.getI64IntegerAttr(0)); 993 } 994 } else if (const auto &scheduleClause = 995 std::get_if<Fortran::parser::OmpClause::Schedule>( 996 &clause.u)) { 997 mlir::MLIRContext *context = firOpBuilder.getContext(); 998 const auto &scheduleType = scheduleClause->v; 999 const auto &scheduleKind = 1000 std::get<Fortran::parser::OmpScheduleClause::ScheduleType>( 1001 scheduleType.t); 1002 switch (scheduleKind) { 1003 case Fortran::parser::OmpScheduleClause::ScheduleType::Static: 1004 wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get( 1005 context, omp::ClauseScheduleKind::Static)); 1006 break; 1007 case Fortran::parser::OmpScheduleClause::ScheduleType::Dynamic: 1008 wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get( 1009 context, omp::ClauseScheduleKind::Dynamic)); 1010 break; 1011 case Fortran::parser::OmpScheduleClause::ScheduleType::Guided: 1012 wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get( 1013 context, omp::ClauseScheduleKind::Guided)); 1014 break; 1015 case Fortran::parser::OmpScheduleClause::ScheduleType::Auto: 1016 wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get( 1017 context, omp::ClauseScheduleKind::Auto)); 1018 break; 1019 case Fortran::parser::OmpScheduleClause::ScheduleType::Runtime: 1020 wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get( 1021 context, omp::ClauseScheduleKind::Runtime)); 1022 break; 1023 } 1024 mlir::omp::ScheduleModifier scheduleModifier = 1025 getScheduleModifier(scheduleClause->v); 1026 if (scheduleModifier != mlir::omp::ScheduleModifier::none) 1027 wsLoopOp.schedule_modifierAttr( 1028 omp::ScheduleModifierAttr::get(context, scheduleModifier)); 1029 if (getSIMDModifier(scheduleClause->v) != 1030 mlir::omp::ScheduleModifier::none) 1031 wsLoopOp.simd_modifierAttr(firOpBuilder.getUnitAttr()); 1032 } 1033 } 1034 // In FORTRAN `nowait` clause occur at the end of `omp do` directive. 1035 // i.e 1036 // !$omp do 1037 // <...> 1038 // !$omp end do nowait 1039 if (const auto &endClauseList = 1040 std::get<std::optional<Fortran::parser::OmpEndLoopDirective>>( 1041 loopConstruct.t)) { 1042 const auto &clauseList = 1043 std::get<Fortran::parser::OmpClauseList>((*endClauseList).t); 1044 for (const Fortran::parser::OmpClause &clause : clauseList.v) 1045 if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u)) 1046 wsLoopOp.nowaitAttr(firOpBuilder.getUnitAttr()); 1047 } 1048 1049 createBodyOfOp<omp::WsLoopOp>(wsLoopOp, converter, currentLocation, eval, 1050 &loopOpClauseList, iv); 1051 } 1052 1053 static void 1054 genOMP(Fortran::lower::AbstractConverter &converter, 1055 Fortran::lower::pft::Evaluation &eval, 1056 const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) { 1057 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 1058 mlir::Location currentLocation = converter.getCurrentLocation(); 1059 std::string name; 1060 const Fortran::parser::OmpCriticalDirective &cd = 1061 std::get<Fortran::parser::OmpCriticalDirective>(criticalConstruct.t); 1062 if (std::get<std::optional<Fortran::parser::Name>>(cd.t).has_value()) { 1063 name = 1064 std::get<std::optional<Fortran::parser::Name>>(cd.t).value().ToString(); 1065 } 1066 1067 uint64_t hint = 0; 1068 const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t); 1069 for (const Fortran::parser::OmpClause &clause : clauseList.v) 1070 if (auto hintClause = 1071 std::get_if<Fortran::parser::OmpClause::Hint>(&clause.u)) { 1072 const auto *expr = Fortran::semantics::GetExpr(hintClause->v); 1073 hint = *Fortran::evaluate::ToInt64(*expr); 1074 break; 1075 } 1076 1077 mlir::omp::CriticalOp criticalOp = [&]() { 1078 if (name.empty()) { 1079 return firOpBuilder.create<mlir::omp::CriticalOp>(currentLocation, 1080 FlatSymbolRefAttr()); 1081 } else { 1082 mlir::ModuleOp module = firOpBuilder.getModule(); 1083 mlir::OpBuilder modBuilder(module.getBodyRegion()); 1084 auto global = module.lookupSymbol<mlir::omp::CriticalDeclareOp>(name); 1085 if (!global) 1086 global = modBuilder.create<mlir::omp::CriticalDeclareOp>( 1087 currentLocation, name, hint); 1088 return firOpBuilder.create<mlir::omp::CriticalOp>( 1089 currentLocation, mlir::FlatSymbolRefAttr::get( 1090 firOpBuilder.getContext(), global.sym_name())); 1091 } 1092 }(); 1093 createBodyOfOp<omp::CriticalOp>(criticalOp, converter, currentLocation, eval); 1094 } 1095 1096 static void 1097 genOMP(Fortran::lower::AbstractConverter &converter, 1098 Fortran::lower::pft::Evaluation &eval, 1099 const Fortran::parser::OpenMPSectionConstruct §ionConstruct) { 1100 1101 auto &firOpBuilder = converter.getFirOpBuilder(); 1102 auto currentLocation = converter.getCurrentLocation(); 1103 mlir::omp::SectionOp sectionOp = 1104 firOpBuilder.create<mlir::omp::SectionOp>(currentLocation); 1105 createBodyOfOp<omp::SectionOp>(sectionOp, converter, currentLocation, eval); 1106 } 1107 1108 // TODO: Add support for reduction 1109 static void 1110 genOMP(Fortran::lower::AbstractConverter &converter, 1111 Fortran::lower::pft::Evaluation &eval, 1112 const Fortran::parser::OpenMPSectionsConstruct §ionsConstruct) { 1113 auto &firOpBuilder = converter.getFirOpBuilder(); 1114 auto currentLocation = converter.getCurrentLocation(); 1115 SmallVector<Value> reductionVars, allocateOperands, allocatorOperands; 1116 mlir::UnitAttr noWaitClauseOperand; 1117 const auto §ionsClauseList = std::get<Fortran::parser::OmpClauseList>( 1118 std::get<Fortran::parser::OmpBeginSectionsDirective>(sectionsConstruct.t) 1119 .t); 1120 for (const Fortran::parser::OmpClause &clause : sectionsClauseList.v) { 1121 1122 // Reduction Clause 1123 if (std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) { 1124 TODO(currentLocation, "OMPC_Reduction"); 1125 1126 // Allocate clause 1127 } else if (const auto &allocateClause = 1128 std::get_if<Fortran::parser::OmpClause::Allocate>( 1129 &clause.u)) { 1130 genAllocateClause(converter, allocateClause->v, allocatorOperands, 1131 allocateOperands); 1132 } 1133 } 1134 const auto &endSectionsClauseList = 1135 std::get<Fortran::parser::OmpEndSectionsDirective>(sectionsConstruct.t); 1136 const auto &clauseList = 1137 std::get<Fortran::parser::OmpClauseList>(endSectionsClauseList.t); 1138 for (const auto &clause : clauseList.v) { 1139 // Nowait clause 1140 if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u)) { 1141 noWaitClauseOperand = firOpBuilder.getUnitAttr(); 1142 } 1143 } 1144 1145 llvm::omp::Directive dir = 1146 std::get<Fortran::parser::OmpSectionsDirective>( 1147 std::get<Fortran::parser::OmpBeginSectionsDirective>( 1148 sectionsConstruct.t) 1149 .t) 1150 .v; 1151 1152 // Parallel Sections Construct 1153 if (dir == llvm::omp::Directive::OMPD_parallel_sections) { 1154 createCombinedParallelOp<Fortran::parser::OmpBeginSectionsDirective>( 1155 converter, eval, 1156 std::get<Fortran::parser::OmpBeginSectionsDirective>( 1157 sectionsConstruct.t)); 1158 auto sectionsOp = firOpBuilder.create<mlir::omp::SectionsOp>( 1159 currentLocation, /*reduction_vars*/ ValueRange(), 1160 /*reductions=*/nullptr, allocateOperands, allocatorOperands, 1161 /*nowait=*/nullptr); 1162 createBodyOfOp(sectionsOp, converter, currentLocation, eval); 1163 1164 // Sections Construct 1165 } else if (dir == llvm::omp::Directive::OMPD_sections) { 1166 auto sectionsOp = firOpBuilder.create<mlir::omp::SectionsOp>( 1167 currentLocation, reductionVars, /*reductions = */ nullptr, 1168 allocateOperands, allocatorOperands, noWaitClauseOperand); 1169 createBodyOfOp<omp::SectionsOp>(sectionsOp, converter, currentLocation, 1170 eval); 1171 } 1172 } 1173 1174 static void genOmpAtomicHintAndMemoryOrderClauses( 1175 Fortran::lower::AbstractConverter &converter, 1176 const Fortran::parser::OmpAtomicClauseList &clauseList, 1177 mlir::IntegerAttr &hint, 1178 mlir::omp::ClauseMemoryOrderKindAttr &memory_order) { 1179 auto &firOpBuilder = converter.getFirOpBuilder(); 1180 for (const auto &clause : clauseList.v) { 1181 if (auto ompClause = std::get_if<Fortran::parser::OmpClause>(&clause.u)) { 1182 if (auto hintClause = 1183 std::get_if<Fortran::parser::OmpClause::Hint>(&ompClause->u)) { 1184 const auto *expr = Fortran::semantics::GetExpr(hintClause->v); 1185 uint64_t hintExprValue = *Fortran::evaluate::ToInt64(*expr); 1186 hint = firOpBuilder.getI64IntegerAttr(hintExprValue); 1187 } 1188 } else if (auto ompMemoryOrderClause = 1189 std::get_if<Fortran::parser::OmpMemoryOrderClause>( 1190 &clause.u)) { 1191 if (std::get_if<Fortran::parser::OmpClause::Acquire>( 1192 &ompMemoryOrderClause->v.u)) { 1193 memory_order = mlir::omp::ClauseMemoryOrderKindAttr::get( 1194 firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Acquire); 1195 } else if (std::get_if<Fortran::parser::OmpClause::Relaxed>( 1196 &ompMemoryOrderClause->v.u)) { 1197 memory_order = mlir::omp::ClauseMemoryOrderKindAttr::get( 1198 firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Relaxed); 1199 } else if (std::get_if<Fortran::parser::OmpClause::SeqCst>( 1200 &ompMemoryOrderClause->v.u)) { 1201 memory_order = mlir::omp::ClauseMemoryOrderKindAttr::get( 1202 firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Seq_cst); 1203 } else if (std::get_if<Fortran::parser::OmpClause::Release>( 1204 &ompMemoryOrderClause->v.u)) { 1205 memory_order = mlir::omp::ClauseMemoryOrderKindAttr::get( 1206 firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Release); 1207 } 1208 } 1209 } 1210 } 1211 1212 static void genOmpAtomicUpdateStatement( 1213 Fortran::lower::AbstractConverter &converter, 1214 Fortran::lower::pft::Evaluation &eval, 1215 const Fortran::parser::Variable &assignmentStmtVariable, 1216 const Fortran::parser::Expr &assignmentStmtExpr, 1217 const Fortran::parser::OmpAtomicClauseList *leftHandClauseList, 1218 const Fortran::parser::OmpAtomicClauseList *rightHandClauseList) { 1219 // Generate `omp.atomic.update` operation for atomic assignment statements 1220 auto &firOpBuilder = converter.getFirOpBuilder(); 1221 auto currentLocation = converter.getCurrentLocation(); 1222 Fortran::lower::StatementContext stmtCtx; 1223 1224 mlir::Value address = fir::getBase(converter.genExprAddr( 1225 *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx)); 1226 // If no hint clause is specified, the effect is as if 1227 // hint(omp_sync_hint_none) had been specified. 1228 mlir::IntegerAttr hint = nullptr; 1229 mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr; 1230 if (leftHandClauseList) 1231 genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList, hint, 1232 memory_order); 1233 if (rightHandClauseList) 1234 genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, hint, 1235 memory_order); 1236 auto atomicUpdateOp = firOpBuilder.create<mlir::omp::AtomicUpdateOp>( 1237 currentLocation, address, hint, memory_order); 1238 1239 //// Generate body of Atomic Update operation 1240 // If an argument for the region is provided then create the block with that 1241 // argument. Also update the symbol's address with the argument mlir value. 1242 mlir::Type varType = 1243 fir::getBase( 1244 converter.genExprValue( 1245 *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx)) 1246 .getType(); 1247 SmallVector<Type> varTys = {varType}; 1248 SmallVector<Location> locs = {currentLocation}; 1249 firOpBuilder.createBlock(&atomicUpdateOp.getRegion(), {}, varTys, locs); 1250 mlir::Value val = 1251 fir::getBase(atomicUpdateOp.getRegion().front().getArgument(0)); 1252 auto varDesignator = 1253 std::get_if<Fortran::common::Indirection<Fortran::parser::Designator>>( 1254 &assignmentStmtVariable.u); 1255 assert(varDesignator && "Variable designator for atomic update assignment " 1256 "statement does not exist"); 1257 const auto *name = getDesignatorNameIfDataRef(varDesignator->value()); 1258 assert(name && name->symbol && 1259 "No symbol attached to atomic update variable"); 1260 converter.bindSymbol(*name->symbol, val); 1261 // Set the insert for the terminator operation to go at the end of the 1262 // block. 1263 mlir::Block &block = atomicUpdateOp.getRegion().back(); 1264 firOpBuilder.setInsertionPointToEnd(&block); 1265 1266 mlir::Value result = fir::getBase(converter.genExprValue( 1267 *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx)); 1268 // Insert the terminator: YieldOp. 1269 firOpBuilder.create<mlir::omp::YieldOp>(currentLocation, result); 1270 // Reset the insert point to before the terminator. 1271 firOpBuilder.setInsertionPointToStart(&block); 1272 } 1273 1274 static void 1275 genOmpAtomicWrite(Fortran::lower::AbstractConverter &converter, 1276 Fortran::lower::pft::Evaluation &eval, 1277 const Fortran::parser::OmpAtomicWrite &atomicWrite) { 1278 auto &firOpBuilder = converter.getFirOpBuilder(); 1279 auto currentLocation = converter.getCurrentLocation(); 1280 // Get the value and address of atomic write operands. 1281 const Fortran::parser::OmpAtomicClauseList &rightHandClauseList = 1282 std::get<2>(atomicWrite.t); 1283 const Fortran::parser::OmpAtomicClauseList &leftHandClauseList = 1284 std::get<0>(atomicWrite.t); 1285 const auto &assignmentStmtExpr = 1286 std::get<Fortran::parser::Expr>(std::get<3>(atomicWrite.t).statement.t); 1287 const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>( 1288 std::get<3>(atomicWrite.t).statement.t); 1289 Fortran::lower::StatementContext stmtCtx; 1290 mlir::Value value = fir::getBase(converter.genExprValue( 1291 *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx)); 1292 mlir::Value address = fir::getBase(converter.genExprAddr( 1293 *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx)); 1294 // If no hint clause is specified, the effect is as if 1295 // hint(omp_sync_hint_none) had been specified. 1296 mlir::IntegerAttr hint = nullptr; 1297 mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr; 1298 genOmpAtomicHintAndMemoryOrderClauses(converter, leftHandClauseList, hint, 1299 memory_order); 1300 genOmpAtomicHintAndMemoryOrderClauses(converter, rightHandClauseList, hint, 1301 memory_order); 1302 firOpBuilder.create<mlir::omp::AtomicWriteOp>(currentLocation, address, value, 1303 hint, memory_order); 1304 } 1305 1306 static void genOmpAtomicRead(Fortran::lower::AbstractConverter &converter, 1307 Fortran::lower::pft::Evaluation &eval, 1308 const Fortran::parser::OmpAtomicRead &atomicRead) { 1309 auto &firOpBuilder = converter.getFirOpBuilder(); 1310 auto currentLocation = converter.getCurrentLocation(); 1311 // Get the address of atomic read operands. 1312 const Fortran::parser::OmpAtomicClauseList &rightHandClauseList = 1313 std::get<2>(atomicRead.t); 1314 const Fortran::parser::OmpAtomicClauseList &leftHandClauseList = 1315 std::get<0>(atomicRead.t); 1316 const auto &assignmentStmtExpr = 1317 std::get<Fortran::parser::Expr>(std::get<3>(atomicRead.t).statement.t); 1318 const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>( 1319 std::get<3>(atomicRead.t).statement.t); 1320 Fortran::lower::StatementContext stmtCtx; 1321 mlir::Value from_address = fir::getBase(converter.genExprAddr( 1322 *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx)); 1323 mlir::Value to_address = fir::getBase(converter.genExprAddr( 1324 *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx)); 1325 // If no hint clause is specified, the effect is as if 1326 // hint(omp_sync_hint_none) had been specified. 1327 mlir::IntegerAttr hint = nullptr; 1328 mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr; 1329 genOmpAtomicHintAndMemoryOrderClauses(converter, leftHandClauseList, hint, 1330 memory_order); 1331 genOmpAtomicHintAndMemoryOrderClauses(converter, rightHandClauseList, hint, 1332 memory_order); 1333 firOpBuilder.create<mlir::omp::AtomicReadOp>(currentLocation, from_address, 1334 to_address, hint, memory_order); 1335 } 1336 1337 static void 1338 genOmpAtomicUpdate(Fortran::lower::AbstractConverter &converter, 1339 Fortran::lower::pft::Evaluation &eval, 1340 const Fortran::parser::OmpAtomicUpdate &atomicUpdate) { 1341 const Fortran::parser::OmpAtomicClauseList &rightHandClauseList = 1342 std::get<2>(atomicUpdate.t); 1343 const Fortran::parser::OmpAtomicClauseList &leftHandClauseList = 1344 std::get<0>(atomicUpdate.t); 1345 const auto &assignmentStmtExpr = 1346 std::get<Fortran::parser::Expr>(std::get<3>(atomicUpdate.t).statement.t); 1347 const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>( 1348 std::get<3>(atomicUpdate.t).statement.t); 1349 1350 genOmpAtomicUpdateStatement(converter, eval, assignmentStmtVariable, 1351 assignmentStmtExpr, &leftHandClauseList, 1352 &rightHandClauseList); 1353 } 1354 1355 static void genOmpAtomic(Fortran::lower::AbstractConverter &converter, 1356 Fortran::lower::pft::Evaluation &eval, 1357 const Fortran::parser::OmpAtomic &atomicConstruct) { 1358 const Fortran::parser::OmpAtomicClauseList &atomicClauseList = 1359 std::get<Fortran::parser::OmpAtomicClauseList>(atomicConstruct.t); 1360 const auto &assignmentStmtExpr = std::get<Fortran::parser::Expr>( 1361 std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>( 1362 atomicConstruct.t) 1363 .statement.t); 1364 const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>( 1365 std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>( 1366 atomicConstruct.t) 1367 .statement.t); 1368 // If atomic-clause is not present on the construct, the behaviour is as if 1369 // the update clause is specified 1370 genOmpAtomicUpdateStatement(converter, eval, assignmentStmtVariable, 1371 assignmentStmtExpr, &atomicClauseList, nullptr); 1372 } 1373 1374 static void 1375 genOMP(Fortran::lower::AbstractConverter &converter, 1376 Fortran::lower::pft::Evaluation &eval, 1377 const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) { 1378 std::visit(Fortran::common::visitors{ 1379 [&](const Fortran::parser::OmpAtomicRead &atomicRead) { 1380 genOmpAtomicRead(converter, eval, atomicRead); 1381 }, 1382 [&](const Fortran::parser::OmpAtomicWrite &atomicWrite) { 1383 genOmpAtomicWrite(converter, eval, atomicWrite); 1384 }, 1385 [&](const Fortran::parser::OmpAtomic &atomicConstruct) { 1386 genOmpAtomic(converter, eval, atomicConstruct); 1387 }, 1388 [&](const Fortran::parser::OmpAtomicUpdate &atomicUpdate) { 1389 genOmpAtomicUpdate(converter, eval, atomicUpdate); 1390 }, 1391 [&](const auto &) { 1392 TODO(converter.getCurrentLocation(), "Atomic capture"); 1393 }, 1394 }, 1395 atomicConstruct.u); 1396 } 1397 1398 void Fortran::lower::genOpenMPConstruct( 1399 Fortran::lower::AbstractConverter &converter, 1400 Fortran::lower::pft::Evaluation &eval, 1401 const Fortran::parser::OpenMPConstruct &ompConstruct) { 1402 1403 std::visit( 1404 common::visitors{ 1405 [&](const Fortran::parser::OpenMPStandaloneConstruct 1406 &standaloneConstruct) { 1407 genOMP(converter, eval, standaloneConstruct); 1408 }, 1409 [&](const Fortran::parser::OpenMPSectionsConstruct 1410 §ionsConstruct) { 1411 genOMP(converter, eval, sectionsConstruct); 1412 }, 1413 [&](const Fortran::parser::OpenMPSectionConstruct §ionConstruct) { 1414 genOMP(converter, eval, sectionConstruct); 1415 }, 1416 [&](const Fortran::parser::OpenMPLoopConstruct &loopConstruct) { 1417 genOMP(converter, eval, loopConstruct); 1418 }, 1419 [&](const Fortran::parser::OpenMPDeclarativeAllocate 1420 &execAllocConstruct) { 1421 TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate"); 1422 }, 1423 [&](const Fortran::parser::OpenMPExecutableAllocate 1424 &execAllocConstruct) { 1425 TODO(converter.getCurrentLocation(), "OpenMPExecutableAllocate"); 1426 }, 1427 [&](const Fortran::parser::OpenMPBlockConstruct &blockConstruct) { 1428 genOMP(converter, eval, blockConstruct); 1429 }, 1430 [&](const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) { 1431 genOMP(converter, eval, atomicConstruct); 1432 }, 1433 [&](const Fortran::parser::OpenMPCriticalConstruct 1434 &criticalConstruct) { 1435 genOMP(converter, eval, criticalConstruct); 1436 }, 1437 }, 1438 ompConstruct.u); 1439 } 1440 1441 void Fortran::lower::genThreadprivateOp( 1442 Fortran::lower::AbstractConverter &converter, 1443 const Fortran::lower::pft::Variable &var) { 1444 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 1445 mlir::Location currentLocation = converter.getCurrentLocation(); 1446 1447 const Fortran::semantics::Symbol &sym = var.getSymbol(); 1448 mlir::Value symThreadprivateValue; 1449 if (const Fortran::semantics::Symbol *common = 1450 Fortran::semantics::FindCommonBlockContaining(sym.GetUltimate())) { 1451 mlir::Value commonValue = converter.getSymbolAddress(*common); 1452 if (mlir::isa<mlir::omp::ThreadprivateOp>(commonValue.getDefiningOp())) { 1453 // Generate ThreadprivateOp for a common block instead of its members and 1454 // only do it once for a common block. 1455 return; 1456 } 1457 // Generate ThreadprivateOp and rebind the common block. 1458 mlir::Value commonThreadprivateValue = 1459 firOpBuilder.create<mlir::omp::ThreadprivateOp>( 1460 currentLocation, commonValue.getType(), commonValue); 1461 converter.bindSymbol(*common, commonThreadprivateValue); 1462 // Generate the threadprivate value for the common block member. 1463 symThreadprivateValue = 1464 genCommonBlockMember(converter, sym, commonThreadprivateValue); 1465 } else { 1466 mlir::Value symValue = converter.getSymbolAddress(sym); 1467 symThreadprivateValue = firOpBuilder.create<mlir::omp::ThreadprivateOp>( 1468 currentLocation, symValue.getType(), symValue); 1469 } 1470 1471 fir::ExtendedValue sexv = converter.getSymbolExtendedValue(sym); 1472 fir::ExtendedValue symThreadprivateExv = 1473 getExtendedValue(sexv, symThreadprivateValue); 1474 converter.bindSymbol(sym, symThreadprivateExv); 1475 } 1476 1477 void Fortran::lower::genOpenMPDeclarativeConstruct( 1478 Fortran::lower::AbstractConverter &converter, 1479 Fortran::lower::pft::Evaluation &eval, 1480 const Fortran::parser::OpenMPDeclarativeConstruct &ompDeclConstruct) { 1481 1482 std::visit( 1483 common::visitors{ 1484 [&](const Fortran::parser::OpenMPDeclarativeAllocate 1485 &declarativeAllocate) { 1486 TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate"); 1487 }, 1488 [&](const Fortran::parser::OpenMPDeclareReductionConstruct 1489 &declareReductionConstruct) { 1490 TODO(converter.getCurrentLocation(), 1491 "OpenMPDeclareReductionConstruct"); 1492 }, 1493 [&](const Fortran::parser::OpenMPDeclareSimdConstruct 1494 &declareSimdConstruct) { 1495 TODO(converter.getCurrentLocation(), "OpenMPDeclareSimdConstruct"); 1496 }, 1497 [&](const Fortran::parser::OpenMPDeclareTargetConstruct 1498 &declareTargetConstruct) { 1499 TODO(converter.getCurrentLocation(), 1500 "OpenMPDeclareTargetConstruct"); 1501 }, 1502 [&](const Fortran::parser::OpenMPThreadprivate &threadprivate) { 1503 // The directive is lowered when instantiating the variable to 1504 // support the case of threadprivate variable declared in module. 1505 }, 1506 }, 1507 ompDeclConstruct.u); 1508 } 1509 1510 // Generate an OpenMP reduction operation. This implementation finds the chain : 1511 // load reduction var -> reduction_operation -> store reduction var and replaces 1512 // it with the reduction operation. 1513 // TODO: Currently assumes it is an integer addition reduction. Generalize this 1514 // for various reduction operation types. 1515 // TODO: Generate the reduction operation during lowering instead of creating 1516 // and removing operations since this is not a robust approach. Also, removing 1517 // ops in the builder (instead of a rewriter) is probably not the best approach. 1518 void Fortran::lower::genOpenMPReduction( 1519 Fortran::lower::AbstractConverter &converter, 1520 const Fortran::parser::OmpClauseList &clauseList) { 1521 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 1522 1523 for (const auto &clause : clauseList.v) { 1524 if (const auto &reductionClause = 1525 std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) { 1526 const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>( 1527 reductionClause->v.t)}; 1528 const auto &objectList{ 1529 std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)}; 1530 if (auto reductionOp = 1531 std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) { 1532 const auto &intrinsicOp{ 1533 std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>( 1534 reductionOp->u)}; 1535 if (intrinsicOp != 1536 Fortran::parser::DefinedOperator::IntrinsicOperator::Add) 1537 continue; 1538 for (const auto &ompObject : objectList.v) { 1539 if (const auto *name{ 1540 Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { 1541 if (const auto *symbol{name->symbol}) { 1542 mlir::Value symVal = converter.getSymbolAddress(*symbol); 1543 mlir::Type redType = 1544 symVal.getType().cast<fir::ReferenceType>().getEleTy(); 1545 if (!redType.isIntOrIndex()) 1546 continue; 1547 for (mlir::OpOperand &use1 : symVal.getUses()) { 1548 if (auto load = mlir::dyn_cast<fir::LoadOp>(use1.getOwner())) { 1549 mlir::Value loadVal = load.getRes(); 1550 for (mlir::OpOperand &use2 : loadVal.getUses()) { 1551 if (auto add = mlir::dyn_cast<mlir::arith::AddIOp>( 1552 use2.getOwner())) { 1553 mlir::Value addRes = add.getResult(); 1554 for (mlir::OpOperand &use3 : addRes.getUses()) { 1555 if (auto store = 1556 mlir::dyn_cast<fir::StoreOp>(use3.getOwner())) { 1557 if (store.getMemref() == symVal) { 1558 // Chain found! Now replace load->reduction->store 1559 // with the OpenMP reduction operation. 1560 mlir::OpBuilder::InsertPoint insertPtDel = 1561 firOpBuilder.saveInsertionPoint(); 1562 firOpBuilder.setInsertionPoint(add); 1563 if (add.getLhs() == loadVal) { 1564 firOpBuilder.create<mlir::omp::ReductionOp>( 1565 add.getLoc(), add.getRhs(), symVal); 1566 } else { 1567 firOpBuilder.create<mlir::omp::ReductionOp>( 1568 add.getLoc(), add.getLhs(), symVal); 1569 } 1570 store.erase(); 1571 add.erase(); 1572 load.erase(); 1573 firOpBuilder.restoreInsertionPoint(insertPtDel); 1574 } 1575 } 1576 } 1577 } 1578 } 1579 } 1580 } 1581 } 1582 } 1583 } 1584 } 1585 } 1586 } 1587 } 1588