1 //===-- OpenMP.cpp -- Open MP directive lowering --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Lower/OpenMP.h" 14 #include "flang/Common/idioms.h" 15 #include "flang/Lower/Bridge.h" 16 #include "flang/Lower/ConvertExpr.h" 17 #include "flang/Lower/PFTBuilder.h" 18 #include "flang/Lower/StatementContext.h" 19 #include "flang/Optimizer/Builder/BoxValue.h" 20 #include "flang/Optimizer/Builder/FIRBuilder.h" 21 #include "flang/Optimizer/Builder/Todo.h" 22 #include "flang/Parser/parse-tree.h" 23 #include "flang/Semantics/tools.h" 24 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 25 #include "mlir/Dialect/SCF/IR/SCF.h" 26 #include "llvm/Frontend/OpenMP/OMPConstants.h" 27 28 using namespace mlir; 29 30 int64_t Fortran::lower::getCollapseValue( 31 const Fortran::parser::OmpClauseList &clauseList) { 32 for (const auto &clause : clauseList.v) { 33 if (const auto &collapseClause = 34 std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u)) { 35 const auto *expr = Fortran::semantics::GetExpr(collapseClause->v); 36 return Fortran::evaluate::ToInt64(*expr).value(); 37 } 38 } 39 return 1; 40 } 41 42 static const Fortran::parser::Name * 43 getDesignatorNameIfDataRef(const Fortran::parser::Designator &designator) { 44 const auto *dataRef = std::get_if<Fortran::parser::DataRef>(&designator.u); 45 return dataRef ? std::get_if<Fortran::parser::Name>(&dataRef->u) : nullptr; 46 } 47 48 static Fortran::semantics::Symbol * 49 getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) { 50 Fortran::semantics::Symbol *sym = nullptr; 51 std::visit(Fortran::common::visitors{ 52 [&](const Fortran::parser::Designator &designator) { 53 if (const Fortran::parser::Name *name = 54 getDesignatorNameIfDataRef(designator)) { 55 sym = name->symbol; 56 } 57 }, 58 [&](const Fortran::parser::Name &name) { sym = name.symbol; }}, 59 ompObject.u); 60 return sym; 61 } 62 63 template <typename T> 64 static void createPrivateVarSyms(Fortran::lower::AbstractConverter &converter, 65 const T *clause, 66 Block *lastPrivBlock = nullptr) { 67 const Fortran::parser::OmpObjectList &ompObjectList = clause->v; 68 for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) { 69 Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); 70 // Privatization for symbols which are pre-determined (like loop index 71 // variables) happen separately, for everything else privatize here. 72 if (sym->test(Fortran::semantics::Symbol::Flag::OmpPreDetermined)) 73 continue; 74 bool success = converter.createHostAssociateVarClone(*sym); 75 (void)success; 76 assert(success && "Privatization failed due to existing binding"); 77 if constexpr (std::is_same_v<T, Fortran::parser::OmpClause::Firstprivate>) { 78 converter.copyHostAssociateVar(*sym); 79 } else if constexpr (std::is_same_v< 80 T, Fortran::parser::OmpClause::Lastprivate>) { 81 converter.copyHostAssociateVar(*sym, lastPrivBlock); 82 } 83 } 84 } 85 86 template <typename Op> 87 static bool privatizeVars(Op &op, Fortran::lower::AbstractConverter &converter, 88 const Fortran::parser::OmpClauseList &opClauseList) { 89 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 90 auto insPt = firOpBuilder.saveInsertionPoint(); 91 firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock()); 92 bool hasFirstPrivateOp = false; 93 bool hasLastPrivateOp = false; 94 Block *lastPrivBlock = nullptr; 95 // We need just one ICmpOp for multiple LastPrivate clauses. 96 mlir::arith::CmpIOp cmpOp; 97 98 for (const Fortran::parser::OmpClause &clause : opClauseList.v) { 99 if (const auto &privateClause = 100 std::get_if<Fortran::parser::OmpClause::Private>(&clause.u)) { 101 createPrivateVarSyms(converter, privateClause); 102 } else if (const auto &firstPrivateClause = 103 std::get_if<Fortran::parser::OmpClause::Firstprivate>( 104 &clause.u)) { 105 createPrivateVarSyms(converter, firstPrivateClause); 106 hasFirstPrivateOp = true; 107 } else if (const auto &lastPrivateClause = 108 std::get_if<Fortran::parser::OmpClause::Lastprivate>( 109 &clause.u)) { 110 // TODO: Add lastprivate support for sections construct, simd construct 111 if (std::is_same_v<Op, omp::WsLoopOp>) { 112 omp::WsLoopOp *wsLoopOp = dyn_cast<omp::WsLoopOp>(&op); 113 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 114 auto insPt = firOpBuilder.saveInsertionPoint(); 115 116 // Our goal here is to introduce the following control flow 117 // just before exiting the worksharing loop. 118 // Say our wsloop is as follows: 119 // 120 // omp.wsloop { 121 // ... 122 // store 123 // omp.yield 124 // } 125 // 126 // We want to convert it to the following: 127 // 128 // omp.wsloop { 129 // ... 130 // store 131 // %cmp = llvm.icmp "eq" %iv %ub 132 // scf.if %cmp { 133 // ^%lpv_update_blk: 134 // } 135 // omp.yield 136 // } 137 138 Operation *lastOper = wsLoopOp->region().back().getTerminator(); 139 140 firOpBuilder.setInsertionPoint(lastOper); 141 142 // TODO: The following will not work when there is collapse present. 143 // Have to modify this in future. 144 for (const Fortran::parser::OmpClause &clause : opClauseList.v) 145 if (const auto &collapseClause = 146 std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u)) 147 TODO(converter.getCurrentLocation(), 148 "Collapse clause with lastprivate"); 149 // Only generate the compare once in presence of multiple LastPrivate 150 // clauses 151 if (!hasLastPrivateOp) { 152 cmpOp = firOpBuilder.create<mlir::arith::CmpIOp>( 153 wsLoopOp->getLoc(), mlir::arith::CmpIPredicate::eq, 154 wsLoopOp->getRegion().front().getArguments()[0], 155 wsLoopOp->upperBound()[0]); 156 } 157 mlir::scf::IfOp ifOp = firOpBuilder.create<mlir::scf::IfOp>( 158 wsLoopOp->getLoc(), cmpOp, /*else*/ false); 159 160 firOpBuilder.restoreInsertionPoint(insPt); 161 createPrivateVarSyms(converter, lastPrivateClause, 162 &(ifOp.getThenRegion().front())); 163 } else { 164 TODO(converter.getCurrentLocation(), 165 "lastprivate clause in constructs other than work-share loop"); 166 } 167 hasLastPrivateOp = true; 168 } 169 } 170 if (hasFirstPrivateOp) 171 firOpBuilder.create<mlir::omp::BarrierOp>(converter.getCurrentLocation()); 172 firOpBuilder.restoreInsertionPoint(insPt); 173 return hasLastPrivateOp; 174 } 175 176 /// The COMMON block is a global structure. \p commonValue is the base address 177 /// of the the COMMON block. As the offset from the symbol \p sym, generate the 178 /// COMMON block member value (commonValue + offset) for the symbol. 179 /// FIXME: Share the code with `instantiateCommon` in ConvertVariable.cpp. 180 static mlir::Value 181 genCommonBlockMember(Fortran::lower::AbstractConverter &converter, 182 const Fortran::semantics::Symbol &sym, 183 mlir::Value commonValue) { 184 auto &firOpBuilder = converter.getFirOpBuilder(); 185 mlir::Location currentLocation = converter.getCurrentLocation(); 186 mlir::IntegerType i8Ty = firOpBuilder.getIntegerType(8); 187 mlir::Type i8Ptr = firOpBuilder.getRefType(i8Ty); 188 mlir::Type seqTy = firOpBuilder.getRefType(firOpBuilder.getVarLenSeqTy(i8Ty)); 189 mlir::Value base = 190 firOpBuilder.createConvert(currentLocation, seqTy, commonValue); 191 std::size_t byteOffset = sym.GetUltimate().offset(); 192 mlir::Value offs = firOpBuilder.createIntegerConstant( 193 currentLocation, firOpBuilder.getIndexType(), byteOffset); 194 mlir::Value varAddr = firOpBuilder.create<fir::CoordinateOp>( 195 currentLocation, i8Ptr, base, mlir::ValueRange{offs}); 196 mlir::Type symType = converter.genType(sym); 197 return firOpBuilder.createConvert(currentLocation, 198 firOpBuilder.getRefType(symType), varAddr); 199 } 200 201 // Get the extended value for \p val by extracting additional variable 202 // information from \p base. 203 static fir::ExtendedValue getExtendedValue(fir::ExtendedValue base, 204 mlir::Value val) { 205 return base.match( 206 [&](const fir::MutableBoxValue &box) -> fir::ExtendedValue { 207 return fir::MutableBoxValue(val, box.nonDeferredLenParams(), {}); 208 }, 209 [&](const auto &) -> fir::ExtendedValue { 210 return fir::substBase(base, val); 211 }); 212 } 213 214 static void threadPrivatizeVars(Fortran::lower::AbstractConverter &converter, 215 Fortran::lower::pft::Evaluation &eval) { 216 auto &firOpBuilder = converter.getFirOpBuilder(); 217 mlir::Location currentLocation = converter.getCurrentLocation(); 218 auto insPt = firOpBuilder.saveInsertionPoint(); 219 firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock()); 220 221 // Get the original ThreadprivateOp corresponding to the symbol and use the 222 // symbol value from that opeartion to create one ThreadprivateOp copy 223 // operation inside the parallel region. 224 auto genThreadprivateOp = [&](Fortran::lower::SymbolRef sym) -> mlir::Value { 225 mlir::Value symOriThreadprivateValue = converter.getSymbolAddress(sym); 226 mlir::Operation *op = symOriThreadprivateValue.getDefiningOp(); 227 assert(mlir::isa<mlir::omp::ThreadprivateOp>(op) && 228 "The threadprivate operation not created"); 229 mlir::Value symValue = 230 mlir::dyn_cast<mlir::omp::ThreadprivateOp>(op).sym_addr(); 231 return firOpBuilder.create<mlir::omp::ThreadprivateOp>( 232 currentLocation, symValue.getType(), symValue); 233 }; 234 235 llvm::SetVector<const Fortran::semantics::Symbol *> threadprivateSyms; 236 converter.collectSymbolSet(eval, threadprivateSyms, 237 Fortran::semantics::Symbol::Flag::OmpThreadprivate, 238 /*isUltimateSymbol=*/false); 239 std::set<Fortran::semantics::SourceName> threadprivateSymNames; 240 241 // For a COMMON block, the ThreadprivateOp is generated for itself instead of 242 // its members, so only bind the value of the new copied ThreadprivateOp 243 // inside the parallel region to the common block symbol only once for 244 // multiple members in one COMMON block. 245 llvm::SetVector<const Fortran::semantics::Symbol *> commonSyms; 246 for (std::size_t i = 0; i < threadprivateSyms.size(); i++) { 247 auto sym = threadprivateSyms[i]; 248 mlir::Value symThreadprivateValue; 249 // The variable may be used more than once, and each reference has one 250 // symbol with the same name. Only do once for references of one variable. 251 if (threadprivateSymNames.find(sym->name()) != threadprivateSymNames.end()) 252 continue; 253 threadprivateSymNames.insert(sym->name()); 254 if (const Fortran::semantics::Symbol *common = 255 Fortran::semantics::FindCommonBlockContaining(sym->GetUltimate())) { 256 mlir::Value commonThreadprivateValue; 257 if (commonSyms.contains(common)) { 258 commonThreadprivateValue = converter.getSymbolAddress(*common); 259 } else { 260 commonThreadprivateValue = genThreadprivateOp(*common); 261 converter.bindSymbol(*common, commonThreadprivateValue); 262 commonSyms.insert(common); 263 } 264 symThreadprivateValue = 265 genCommonBlockMember(converter, *sym, commonThreadprivateValue); 266 } else { 267 symThreadprivateValue = genThreadprivateOp(*sym); 268 } 269 270 fir::ExtendedValue sexv = converter.getSymbolExtendedValue(*sym); 271 fir::ExtendedValue symThreadprivateExv = 272 getExtendedValue(sexv, symThreadprivateValue); 273 converter.bindSymbol(*sym, symThreadprivateExv); 274 } 275 276 firOpBuilder.restoreInsertionPoint(insPt); 277 } 278 279 static void 280 genCopyinClause(Fortran::lower::AbstractConverter &converter, 281 const Fortran::parser::OmpClauseList &opClauseList) { 282 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 283 mlir::OpBuilder::InsertPoint insPt = firOpBuilder.saveInsertionPoint(); 284 firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock()); 285 bool hasCopyin = false; 286 for (const Fortran::parser::OmpClause &clause : opClauseList.v) { 287 if (const auto ©inClause = 288 std::get_if<Fortran::parser::OmpClause::Copyin>(&clause.u)) { 289 hasCopyin = true; 290 const Fortran::parser::OmpObjectList &ompObjectList = copyinClause->v; 291 for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) { 292 Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); 293 if (sym->has<Fortran::semantics::CommonBlockDetails>()) 294 TODO(converter.getCurrentLocation(), "common block in Copyin clause"); 295 if (Fortran::semantics::IsAllocatableOrPointer(sym->GetUltimate())) 296 TODO(converter.getCurrentLocation(), 297 "pointer or allocatable variables in Copyin clause"); 298 assert(sym->has<Fortran::semantics::HostAssocDetails>() && 299 "No host-association found"); 300 converter.copyHostAssociateVar(*sym); 301 } 302 } 303 } 304 // [OMP 5.0, 2.19.6.1] The copy is done after the team is formed and prior to 305 // the execution of the associated structured block. Emit implicit barrier to 306 // synchronize threads and avoid data races on propagation master's thread 307 // values of threadprivate variables to local instances of that variables of 308 // all other implicit threads. 309 if (hasCopyin) 310 firOpBuilder.create<mlir::omp::BarrierOp>(converter.getCurrentLocation()); 311 firOpBuilder.restoreInsertionPoint(insPt); 312 } 313 314 static void genObjectList(const Fortran::parser::OmpObjectList &objectList, 315 Fortran::lower::AbstractConverter &converter, 316 llvm::SmallVectorImpl<Value> &operands) { 317 auto addOperands = [&](Fortran::lower::SymbolRef sym) { 318 const mlir::Value variable = converter.getSymbolAddress(sym); 319 if (variable) { 320 operands.push_back(variable); 321 } else { 322 if (const auto *details = 323 sym->detailsIf<Fortran::semantics::HostAssocDetails>()) { 324 operands.push_back(converter.getSymbolAddress(details->symbol())); 325 converter.copySymbolBinding(details->symbol(), sym); 326 } 327 } 328 }; 329 for (const Fortran::parser::OmpObject &ompObject : objectList.v) { 330 Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); 331 addOperands(*sym); 332 } 333 } 334 335 static mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter, 336 std::size_t loopVarTypeSize) { 337 // OpenMP runtime requires 32-bit or 64-bit loop variables. 338 loopVarTypeSize = loopVarTypeSize * 8; 339 if (loopVarTypeSize < 32) { 340 loopVarTypeSize = 32; 341 } else if (loopVarTypeSize > 64) { 342 loopVarTypeSize = 64; 343 mlir::emitWarning(converter.getCurrentLocation(), 344 "OpenMP loop iteration variable cannot have more than 64 " 345 "bits size and will be narrowed into 64 bits."); 346 } 347 assert((loopVarTypeSize == 32 || loopVarTypeSize == 64) && 348 "OpenMP loop iteration variable size must be transformed into 32-bit " 349 "or 64-bit"); 350 return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize); 351 } 352 353 /// Create empty blocks for the current region. 354 /// These blocks replace blocks parented to an enclosing region. 355 void createEmptyRegionBlocks( 356 fir::FirOpBuilder &firOpBuilder, 357 std::list<Fortran::lower::pft::Evaluation> &evaluationList) { 358 auto *region = &firOpBuilder.getRegion(); 359 for (auto &eval : evaluationList) { 360 if (eval.block) { 361 if (eval.block->empty()) { 362 eval.block->erase(); 363 eval.block = firOpBuilder.createBlock(region); 364 } else { 365 [[maybe_unused]] auto &terminatorOp = eval.block->back(); 366 assert((mlir::isa<mlir::omp::TerminatorOp>(terminatorOp) || 367 mlir::isa<mlir::omp::YieldOp>(terminatorOp)) && 368 "expected terminator op"); 369 } 370 } 371 if (!eval.isDirective() && eval.hasNestedEvaluations()) 372 createEmptyRegionBlocks(firOpBuilder, eval.getNestedEvaluations()); 373 } 374 } 375 376 void resetBeforeTerminator(fir::FirOpBuilder &firOpBuilder, 377 mlir::Operation *storeOp, mlir::Block &block) { 378 if (storeOp) 379 firOpBuilder.setInsertionPointAfter(storeOp); 380 else 381 firOpBuilder.setInsertionPointToStart(&block); 382 } 383 384 /// Create the body (block) for an OpenMP Operation. 385 /// 386 /// \param [in] op - the operation the body belongs to. 387 /// \param [inout] converter - converter to use for the clauses. 388 /// \param [in] loc - location in source code. 389 /// \param [in] eval - current PFT node/evaluation. 390 /// \oaran [in] clauses - list of clauses to process. 391 /// \param [in] args - block arguments (induction variable[s]) for the 392 //// region. 393 /// \param [in] outerCombined - is this an outer operation - prevents 394 /// privatization. 395 template <typename Op> 396 static void 397 createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter, 398 mlir::Location &loc, Fortran::lower::pft::Evaluation &eval, 399 const Fortran::parser::OmpClauseList *clauses = nullptr, 400 const SmallVector<const Fortran::semantics::Symbol *> &args = {}, 401 bool outerCombined = false) { 402 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 403 // If an argument for the region is provided then create the block with that 404 // argument. Also update the symbol's address with the mlir argument value. 405 // e.g. For loops the argument is the induction variable. And all further 406 // uses of the induction variable should use this mlir value. 407 mlir::Operation *storeOp = nullptr; 408 if (args.size()) { 409 std::size_t loopVarTypeSize = 0; 410 for (const Fortran::semantics::Symbol *arg : args) 411 loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size()); 412 mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); 413 SmallVector<Type> tiv; 414 SmallVector<Location> locs; 415 for (int i = 0; i < (int)args.size(); i++) { 416 tiv.push_back(loopVarType); 417 locs.push_back(loc); 418 } 419 firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs); 420 int argIndex = 0; 421 // The argument is not currently in memory, so make a temporary for the 422 // argument, and store it there, then bind that location to the argument. 423 for (const Fortran::semantics::Symbol *arg : args) { 424 mlir::Value val = 425 fir::getBase(op.getRegion().front().getArgument(argIndex)); 426 mlir::Value temp = firOpBuilder.createTemporary( 427 loc, loopVarType, 428 llvm::ArrayRef<mlir::NamedAttribute>{ 429 Fortran::lower::getAdaptToByRefAttr(firOpBuilder)}); 430 storeOp = firOpBuilder.create<fir::StoreOp>(loc, val, temp); 431 converter.bindSymbol(*arg, temp); 432 argIndex++; 433 } 434 } else { 435 firOpBuilder.createBlock(&op.getRegion()); 436 } 437 // Set the insert for the terminator operation to go at the end of the 438 // block - this is either empty or the block with the stores above, 439 // the end of the block works for both. 440 mlir::Block &block = op.getRegion().back(); 441 firOpBuilder.setInsertionPointToEnd(&block); 442 443 // If it is an unstructured region and is not the outer region of a combined 444 // construct, create empty blocks for all evaluations. 445 if (eval.lowerAsUnstructured() && !outerCombined) 446 createEmptyRegionBlocks(firOpBuilder, eval.getNestedEvaluations()); 447 448 // Insert the terminator. 449 if constexpr (std::is_same_v<Op, omp::WsLoopOp> || 450 std::is_same_v<Op, omp::SimdLoopOp>) { 451 mlir::ValueRange results; 452 firOpBuilder.create<mlir::omp::YieldOp>(loc, results); 453 } else { 454 firOpBuilder.create<mlir::omp::TerminatorOp>(loc); 455 } 456 457 // Reset the insert point to before the terminator. 458 resetBeforeTerminator(firOpBuilder, storeOp, block); 459 460 // Handle privatization. Do not privatize if this is the outer operation. 461 if (clauses && !outerCombined) { 462 bool lastPrivateOp = privatizeVars(op, converter, *clauses); 463 // LastPrivatization, due to introduction of 464 // new control flow, changes the insertion point, 465 // thus restore it. 466 // TODO: Clean up later a bit to avoid this many sets and resets. 467 if (lastPrivateOp) 468 resetBeforeTerminator(firOpBuilder, storeOp, block); 469 } 470 471 if constexpr (std::is_same_v<Op, omp::ParallelOp>) { 472 threadPrivatizeVars(converter, eval); 473 if (clauses) 474 genCopyinClause(converter, *clauses); 475 } 476 } 477 478 static void genOMP(Fortran::lower::AbstractConverter &converter, 479 Fortran::lower::pft::Evaluation &eval, 480 const Fortran::parser::OpenMPSimpleStandaloneConstruct 481 &simpleStandaloneConstruct) { 482 const auto &directive = 483 std::get<Fortran::parser::OmpSimpleStandaloneDirective>( 484 simpleStandaloneConstruct.t); 485 switch (directive.v) { 486 default: 487 break; 488 case llvm::omp::Directive::OMPD_barrier: 489 converter.getFirOpBuilder().create<mlir::omp::BarrierOp>( 490 converter.getCurrentLocation()); 491 break; 492 case llvm::omp::Directive::OMPD_taskwait: 493 converter.getFirOpBuilder().create<mlir::omp::TaskwaitOp>( 494 converter.getCurrentLocation()); 495 break; 496 case llvm::omp::Directive::OMPD_taskyield: 497 converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>( 498 converter.getCurrentLocation()); 499 break; 500 case llvm::omp::Directive::OMPD_target_enter_data: 501 TODO(converter.getCurrentLocation(), "OMPD_target_enter_data"); 502 case llvm::omp::Directive::OMPD_target_exit_data: 503 TODO(converter.getCurrentLocation(), "OMPD_target_exit_data"); 504 case llvm::omp::Directive::OMPD_target_update: 505 TODO(converter.getCurrentLocation(), "OMPD_target_update"); 506 case llvm::omp::Directive::OMPD_ordered: 507 TODO(converter.getCurrentLocation(), "OMPD_ordered"); 508 } 509 } 510 511 static void 512 genAllocateClause(Fortran::lower::AbstractConverter &converter, 513 const Fortran::parser::OmpAllocateClause &ompAllocateClause, 514 SmallVector<Value> &allocatorOperands, 515 SmallVector<Value> &allocateOperands) { 516 auto &firOpBuilder = converter.getFirOpBuilder(); 517 auto currentLocation = converter.getCurrentLocation(); 518 Fortran::lower::StatementContext stmtCtx; 519 520 mlir::Value allocatorOperand; 521 const Fortran::parser::OmpObjectList &ompObjectList = 522 std::get<Fortran::parser::OmpObjectList>(ompAllocateClause.t); 523 const auto &allocatorValue = 524 std::get<std::optional<Fortran::parser::OmpAllocateClause::Allocator>>( 525 ompAllocateClause.t); 526 // Check if allocate clause has allocator specified. If so, add it 527 // to list of allocators, otherwise, add default allocator to 528 // list of allocators. 529 if (allocatorValue) { 530 allocatorOperand = fir::getBase(converter.genExprValue( 531 *Fortran::semantics::GetExpr(allocatorValue->v), stmtCtx)); 532 allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), 533 allocatorOperand); 534 } else { 535 allocatorOperand = firOpBuilder.createIntegerConstant( 536 currentLocation, firOpBuilder.getI32Type(), 1); 537 allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), 538 allocatorOperand); 539 } 540 genObjectList(ompObjectList, converter, allocateOperands); 541 } 542 543 static void 544 genOMP(Fortran::lower::AbstractConverter &converter, 545 Fortran::lower::pft::Evaluation &eval, 546 const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) { 547 std::visit( 548 Fortran::common::visitors{ 549 [&](const Fortran::parser::OpenMPSimpleStandaloneConstruct 550 &simpleStandaloneConstruct) { 551 genOMP(converter, eval, simpleStandaloneConstruct); 552 }, 553 [&](const Fortran::parser::OpenMPFlushConstruct &flushConstruct) { 554 SmallVector<Value, 4> operandRange; 555 if (const auto &ompObjectList = 556 std::get<std::optional<Fortran::parser::OmpObjectList>>( 557 flushConstruct.t)) 558 genObjectList(*ompObjectList, converter, operandRange); 559 const auto &memOrderClause = std::get<std::optional< 560 std::list<Fortran::parser::OmpMemoryOrderClause>>>( 561 flushConstruct.t); 562 if (memOrderClause.has_value() && memOrderClause->size() > 0) 563 TODO(converter.getCurrentLocation(), 564 "Handle OmpMemoryOrderClause"); 565 converter.getFirOpBuilder().create<mlir::omp::FlushOp>( 566 converter.getCurrentLocation(), operandRange); 567 }, 568 [&](const Fortran::parser::OpenMPCancelConstruct &cancelConstruct) { 569 TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct"); 570 }, 571 [&](const Fortran::parser::OpenMPCancellationPointConstruct 572 &cancellationPointConstruct) { 573 TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct"); 574 }, 575 }, 576 standaloneConstruct.u); 577 } 578 579 static omp::ClauseProcBindKindAttr genProcBindKindAttr( 580 fir::FirOpBuilder &firOpBuilder, 581 const Fortran::parser::OmpClause::ProcBind *procBindClause) { 582 omp::ClauseProcBindKind pbKind; 583 switch (procBindClause->v.v) { 584 case Fortran::parser::OmpProcBindClause::Type::Master: 585 pbKind = omp::ClauseProcBindKind::Master; 586 break; 587 case Fortran::parser::OmpProcBindClause::Type::Close: 588 pbKind = omp::ClauseProcBindKind::Close; 589 break; 590 case Fortran::parser::OmpProcBindClause::Type::Spread: 591 pbKind = omp::ClauseProcBindKind::Spread; 592 break; 593 case Fortran::parser::OmpProcBindClause::Type::Primary: 594 pbKind = omp::ClauseProcBindKind::Primary; 595 break; 596 } 597 return omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(), pbKind); 598 } 599 600 static mlir::Value 601 getIfClauseOperand(Fortran::lower::AbstractConverter &converter, 602 Fortran::lower::StatementContext &stmtCtx, 603 const Fortran::parser::OmpClause::If *ifClause) { 604 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 605 mlir::Location currentLocation = converter.getCurrentLocation(); 606 auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t); 607 mlir::Value ifVal = fir::getBase( 608 converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)); 609 return firOpBuilder.createConvert(currentLocation, firOpBuilder.getI1Type(), 610 ifVal); 611 } 612 613 /* When parallel is used in a combined construct, then use this function to 614 * create the parallel operation. It handles the parallel specific clauses 615 * and leaves the rest for handling at the inner operations. 616 * TODO: Refactor clause handling 617 */ 618 template <typename Directive> 619 static void 620 createCombinedParallelOp(Fortran::lower::AbstractConverter &converter, 621 Fortran::lower::pft::Evaluation &eval, 622 const Directive &directive) { 623 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 624 mlir::Location currentLocation = converter.getCurrentLocation(); 625 Fortran::lower::StatementContext stmtCtx; 626 llvm::ArrayRef<mlir::Type> argTy; 627 mlir::Value ifClauseOperand, numThreadsClauseOperand; 628 SmallVector<Value> allocatorOperands, allocateOperands; 629 mlir::omp::ClauseProcBindKindAttr procBindKindAttr; 630 const auto &opClauseList = 631 std::get<Fortran::parser::OmpClauseList>(directive.t); 632 // TODO: Handle the following clauses 633 // 1. default 634 // Note: rest of the clauses are handled when the inner operation is created 635 for (const Fortran::parser::OmpClause &clause : opClauseList.v) { 636 if (const auto &ifClause = 637 std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) { 638 ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause); 639 } else if (const auto &numThreadsClause = 640 std::get_if<Fortran::parser::OmpClause::NumThreads>( 641 &clause.u)) { 642 numThreadsClauseOperand = fir::getBase(converter.genExprValue( 643 *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx)); 644 } else if (const auto &procBindClause = 645 std::get_if<Fortran::parser::OmpClause::ProcBind>( 646 &clause.u)) { 647 procBindKindAttr = genProcBindKindAttr(firOpBuilder, procBindClause); 648 } 649 } 650 // Create and insert the operation. 651 auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>( 652 currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand, 653 allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(), 654 /*reductions=*/nullptr, procBindKindAttr); 655 656 createBodyOfOp<omp::ParallelOp>(parallelOp, converter, currentLocation, eval, 657 &opClauseList, /*iv=*/{}, 658 /*isCombined=*/true); 659 } 660 661 static void 662 genOMP(Fortran::lower::AbstractConverter &converter, 663 Fortran::lower::pft::Evaluation &eval, 664 const Fortran::parser::OpenMPBlockConstruct &blockConstruct) { 665 const auto &beginBlockDirective = 666 std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t); 667 const auto &blockDirective = 668 std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t); 669 const auto &endBlockDirective = 670 std::get<Fortran::parser::OmpEndBlockDirective>(blockConstruct.t); 671 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 672 mlir::Location currentLocation = converter.getCurrentLocation(); 673 674 Fortran::lower::StatementContext stmtCtx; 675 llvm::ArrayRef<mlir::Type> argTy; 676 mlir::Value ifClauseOperand, numThreadsClauseOperand, finalClauseOperand, 677 priorityClauseOperand; 678 mlir::omp::ClauseProcBindKindAttr procBindKindAttr; 679 SmallVector<Value> allocateOperands, allocatorOperands; 680 mlir::UnitAttr nowaitAttr, untiedAttr, mergeableAttr; 681 682 const auto &opClauseList = 683 std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t); 684 for (const auto &clause : opClauseList.v) { 685 if (const auto &ifClause = 686 std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) { 687 ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause); 688 } else if (const auto &numThreadsClause = 689 std::get_if<Fortran::parser::OmpClause::NumThreads>( 690 &clause.u)) { 691 // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`. 692 numThreadsClauseOperand = fir::getBase(converter.genExprValue( 693 *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx)); 694 } else if (const auto &procBindClause = 695 std::get_if<Fortran::parser::OmpClause::ProcBind>( 696 &clause.u)) { 697 procBindKindAttr = genProcBindKindAttr(firOpBuilder, procBindClause); 698 } else if (const auto &allocateClause = 699 std::get_if<Fortran::parser::OmpClause::Allocate>( 700 &clause.u)) { 701 genAllocateClause(converter, allocateClause->v, allocatorOperands, 702 allocateOperands); 703 } else if (std::get_if<Fortran::parser::OmpClause::Private>(&clause.u) || 704 std::get_if<Fortran::parser::OmpClause::Firstprivate>( 705 &clause.u) || 706 std::get_if<Fortran::parser::OmpClause::Copyin>(&clause.u)) { 707 // Privatisation and copyin clauses are handled elsewhere. 708 continue; 709 } else if (std::get_if<Fortran::parser::OmpClause::Shared>(&clause.u)) { 710 // Shared is the default behavior in the IR, so no handling is required. 711 continue; 712 } else if (const auto &defaultClause = 713 std::get_if<Fortran::parser::OmpClause::Default>( 714 &clause.u)) { 715 if ((defaultClause->v.v == 716 Fortran::parser::OmpDefaultClause::Type::Shared) || 717 (defaultClause->v.v == 718 Fortran::parser::OmpDefaultClause::Type::None)) { 719 // Default clause with shared or none do not require any handling since 720 // Shared is the default behavior in the IR and None is only required 721 // for semantic checks. 722 continue; 723 } 724 } else if (std::get_if<Fortran::parser::OmpClause::Threads>(&clause.u)) { 725 // Nothing needs to be done for threads clause. 726 continue; 727 } else if (const auto &finalClause = 728 std::get_if<Fortran::parser::OmpClause::Final>(&clause.u)) { 729 mlir::Value finalVal = fir::getBase(converter.genExprValue( 730 *Fortran::semantics::GetExpr(finalClause->v), stmtCtx)); 731 finalClauseOperand = firOpBuilder.createConvert( 732 currentLocation, firOpBuilder.getI1Type(), finalVal); 733 } else if (std::get_if<Fortran::parser::OmpClause::Untied>(&clause.u)) { 734 untiedAttr = firOpBuilder.getUnitAttr(); 735 } else if (std::get_if<Fortran::parser::OmpClause::Mergeable>(&clause.u)) { 736 mergeableAttr = firOpBuilder.getUnitAttr(); 737 } else if (const auto &priorityClause = 738 std::get_if<Fortran::parser::OmpClause::Priority>( 739 &clause.u)) { 740 priorityClauseOperand = fir::getBase(converter.genExprValue( 741 *Fortran::semantics::GetExpr(priorityClause->v), stmtCtx)); 742 } else { 743 TODO(currentLocation, "OpenMP Block construct clauses"); 744 } 745 } 746 747 for (const auto &clause : 748 std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t).v) { 749 if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u)) 750 nowaitAttr = firOpBuilder.getUnitAttr(); 751 } 752 753 if (blockDirective.v == llvm::omp::OMPD_parallel) { 754 // Create and insert the operation. 755 auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>( 756 currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand, 757 allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(), 758 /*reductions=*/nullptr, procBindKindAttr); 759 createBodyOfOp<omp::ParallelOp>(parallelOp, converter, currentLocation, 760 eval, &opClauseList); 761 } else if (blockDirective.v == llvm::omp::OMPD_master) { 762 auto masterOp = 763 firOpBuilder.create<mlir::omp::MasterOp>(currentLocation, argTy); 764 createBodyOfOp<omp::MasterOp>(masterOp, converter, currentLocation, eval); 765 } else if (blockDirective.v == llvm::omp::OMPD_single) { 766 auto singleOp = firOpBuilder.create<mlir::omp::SingleOp>( 767 currentLocation, allocateOperands, allocatorOperands, nowaitAttr); 768 createBodyOfOp<omp::SingleOp>(singleOp, converter, currentLocation, eval); 769 } else if (blockDirective.v == llvm::omp::OMPD_ordered) { 770 auto orderedOp = firOpBuilder.create<mlir::omp::OrderedRegionOp>( 771 currentLocation, /*simd=*/nullptr); 772 createBodyOfOp<omp::OrderedRegionOp>(orderedOp, converter, currentLocation, 773 eval); 774 } else if (blockDirective.v == llvm::omp::OMPD_task) { 775 auto taskOp = firOpBuilder.create<mlir::omp::TaskOp>( 776 currentLocation, ifClauseOperand, finalClauseOperand, untiedAttr, 777 mergeableAttr, /*in_reduction_vars=*/ValueRange(), 778 /*in_reductions=*/nullptr, priorityClauseOperand, allocateOperands, 779 allocatorOperands); 780 createBodyOfOp(taskOp, converter, currentLocation, eval, &opClauseList); 781 } else { 782 TODO(converter.getCurrentLocation(), "Unhandled block directive"); 783 } 784 } 785 786 /// Creates an OpenMP reduction declaration and inserts it into the provided 787 /// symbol table. The declaration has a constant initializer with the neutral 788 /// value `initValue`, and the reduction combiner carried over from `reduce`. 789 /// TODO: Generalize this for non-integer types, add atomic region. 790 static omp::ReductionDeclareOp createReductionDecl(fir::FirOpBuilder &builder, 791 llvm::StringRef name, 792 mlir::Type type, 793 mlir::Location loc) { 794 OpBuilder::InsertionGuard guard(builder); 795 mlir::ModuleOp module = builder.getModule(); 796 mlir::OpBuilder modBuilder(module.getBodyRegion()); 797 auto decl = module.lookupSymbol<mlir::omp::ReductionDeclareOp>(name); 798 if (!decl) 799 decl = modBuilder.create<omp::ReductionDeclareOp>(loc, name, type); 800 else 801 return decl; 802 803 builder.createBlock(&decl.initializerRegion(), decl.initializerRegion().end(), 804 {type}, {loc}); 805 builder.setInsertionPointToEnd(&decl.initializerRegion().back()); 806 Value init = builder.create<mlir::arith::ConstantOp>( 807 loc, type, builder.getIntegerAttr(type, 0)); 808 builder.create<omp::YieldOp>(loc, init); 809 810 builder.createBlock(&decl.reductionRegion(), decl.reductionRegion().end(), 811 {type, type}, {loc, loc}); 812 builder.setInsertionPointToEnd(&decl.reductionRegion().back()); 813 mlir::Value op1 = decl.reductionRegion().front().getArgument(0); 814 mlir::Value op2 = decl.reductionRegion().front().getArgument(1); 815 Value addRes = builder.create<mlir::arith::AddIOp>(loc, op1, op2); 816 builder.create<omp::YieldOp>(loc, addRes); 817 return decl; 818 } 819 820 static mlir::omp::ScheduleModifier 821 translateModifier(const Fortran::parser::OmpScheduleModifierType &m) { 822 switch (m.v) { 823 case Fortran::parser::OmpScheduleModifierType::ModType::Monotonic: 824 return mlir::omp::ScheduleModifier::monotonic; 825 case Fortran::parser::OmpScheduleModifierType::ModType::Nonmonotonic: 826 return mlir::omp::ScheduleModifier::nonmonotonic; 827 case Fortran::parser::OmpScheduleModifierType::ModType::Simd: 828 return mlir::omp::ScheduleModifier::simd; 829 } 830 return mlir::omp::ScheduleModifier::none; 831 } 832 833 static mlir::omp::ScheduleModifier 834 getScheduleModifier(const Fortran::parser::OmpScheduleClause &x) { 835 const auto &modifier = 836 std::get<std::optional<Fortran::parser::OmpScheduleModifier>>(x.t); 837 // The input may have the modifier any order, so we look for one that isn't 838 // SIMD. If modifier is not set at all, fall down to the bottom and return 839 // "none". 840 if (modifier) { 841 const auto &modType1 = 842 std::get<Fortran::parser::OmpScheduleModifier::Modifier1>(modifier->t); 843 if (modType1.v.v == 844 Fortran::parser::OmpScheduleModifierType::ModType::Simd) { 845 const auto &modType2 = std::get< 846 std::optional<Fortran::parser::OmpScheduleModifier::Modifier2>>( 847 modifier->t); 848 if (modType2 && 849 modType2->v.v != 850 Fortran::parser::OmpScheduleModifierType::ModType::Simd) 851 return translateModifier(modType2->v); 852 853 return mlir::omp::ScheduleModifier::none; 854 } 855 856 return translateModifier(modType1.v); 857 } 858 return mlir::omp::ScheduleModifier::none; 859 } 860 861 static mlir::omp::ScheduleModifier 862 getSIMDModifier(const Fortran::parser::OmpScheduleClause &x) { 863 const auto &modifier = 864 std::get<std::optional<Fortran::parser::OmpScheduleModifier>>(x.t); 865 // Either of the two possible modifiers in the input can be the SIMD modifier, 866 // so look in either one, and return simd if we find one. Not found = return 867 // "none". 868 if (modifier) { 869 const auto &modType1 = 870 std::get<Fortran::parser::OmpScheduleModifier::Modifier1>(modifier->t); 871 if (modType1.v.v == Fortran::parser::OmpScheduleModifierType::ModType::Simd) 872 return mlir::omp::ScheduleModifier::simd; 873 874 const auto &modType2 = std::get< 875 std::optional<Fortran::parser::OmpScheduleModifier::Modifier2>>( 876 modifier->t); 877 if (modType2 && modType2->v.v == 878 Fortran::parser::OmpScheduleModifierType::ModType::Simd) 879 return mlir::omp::ScheduleModifier::simd; 880 } 881 return mlir::omp::ScheduleModifier::none; 882 } 883 884 static std::string getReductionName( 885 Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, 886 mlir::Type ty) { 887 std::string reductionName; 888 if (intrinsicOp == Fortran::parser::DefinedOperator::IntrinsicOperator::Add) 889 reductionName = "add_reduction"; 890 else 891 reductionName = "other_reduction"; 892 893 return (llvm::Twine(reductionName) + 894 (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) + 895 llvm::Twine(ty.getIntOrFloatBitWidth())) 896 .str(); 897 } 898 899 static void genOMP(Fortran::lower::AbstractConverter &converter, 900 Fortran::lower::pft::Evaluation &eval, 901 const Fortran::parser::OpenMPLoopConstruct &loopConstruct) { 902 903 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 904 mlir::Location currentLocation = converter.getCurrentLocation(); 905 llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, linearVars, 906 linearStepVars, reductionVars; 907 mlir::Value scheduleChunkClauseOperand, ifClauseOperand; 908 mlir::Attribute scheduleClauseOperand, noWaitClauseOperand, 909 orderedClauseOperand, orderClauseOperand; 910 SmallVector<Attribute> reductionDeclSymbols; 911 Fortran::lower::StatementContext stmtCtx; 912 const auto &loopOpClauseList = std::get<Fortran::parser::OmpClauseList>( 913 std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t).t); 914 915 const auto ompDirective = 916 std::get<Fortran::parser::OmpLoopDirective>( 917 std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t).t) 918 .v; 919 if (llvm::omp::OMPD_parallel_do == ompDirective) { 920 createCombinedParallelOp<Fortran::parser::OmpBeginLoopDirective>( 921 converter, eval, 922 std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t)); 923 } else if (llvm::omp::OMPD_do != ompDirective && 924 llvm::omp::OMPD_simd != ompDirective) { 925 TODO(converter.getCurrentLocation(), "Construct enclosing do loop"); 926 } 927 928 // Collect the loops to collapse. 929 auto *doConstructEval = &eval.getFirstNestedEvaluation(); 930 931 std::int64_t collapseValue = 932 Fortran::lower::getCollapseValue(loopOpClauseList); 933 std::size_t loopVarTypeSize = 0; 934 SmallVector<const Fortran::semantics::Symbol *> iv; 935 do { 936 auto *doLoop = &doConstructEval->getFirstNestedEvaluation(); 937 auto *doStmt = doLoop->getIf<Fortran::parser::NonLabelDoStmt>(); 938 assert(doStmt && "Expected do loop to be in the nested evaluation"); 939 const auto &loopControl = 940 std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t); 941 const Fortran::parser::LoopControl::Bounds *bounds = 942 std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u); 943 assert(bounds && "Expected bounds for worksharing do loop"); 944 Fortran::lower::StatementContext stmtCtx; 945 lowerBound.push_back(fir::getBase(converter.genExprValue( 946 *Fortran::semantics::GetExpr(bounds->lower), stmtCtx))); 947 upperBound.push_back(fir::getBase(converter.genExprValue( 948 *Fortran::semantics::GetExpr(bounds->upper), stmtCtx))); 949 if (bounds->step) { 950 step.push_back(fir::getBase(converter.genExprValue( 951 *Fortran::semantics::GetExpr(bounds->step), stmtCtx))); 952 } else { // If `step` is not present, assume it as `1`. 953 step.push_back(firOpBuilder.createIntegerConstant( 954 currentLocation, firOpBuilder.getIntegerType(32), 1)); 955 } 956 iv.push_back(bounds->name.thing.symbol); 957 loopVarTypeSize = std::max(loopVarTypeSize, 958 bounds->name.thing.symbol->GetUltimate().size()); 959 960 collapseValue--; 961 doConstructEval = 962 &*std::next(doConstructEval->getNestedEvaluations().begin()); 963 } while (collapseValue > 0); 964 965 for (const auto &clause : loopOpClauseList.v) { 966 if (const auto &scheduleClause = 967 std::get_if<Fortran::parser::OmpClause::Schedule>(&clause.u)) { 968 if (const auto &chunkExpr = 969 std::get<std::optional<Fortran::parser::ScalarIntExpr>>( 970 scheduleClause->v.t)) { 971 if (const auto *expr = Fortran::semantics::GetExpr(*chunkExpr)) { 972 scheduleChunkClauseOperand = 973 fir::getBase(converter.genExprValue(*expr, stmtCtx)); 974 } 975 } 976 } else if (const auto &ifClause = 977 std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) { 978 ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause); 979 } else if (const auto &reductionClause = 980 std::get_if<Fortran::parser::OmpClause::Reduction>( 981 &clause.u)) { 982 omp::ReductionDeclareOp decl; 983 const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>( 984 reductionClause->v.t)}; 985 const auto &objectList{ 986 std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)}; 987 if (const auto &redDefinedOp = 988 std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) { 989 const auto &intrinsicOp{ 990 std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>( 991 redDefinedOp->u)}; 992 if (intrinsicOp != 993 Fortran::parser::DefinedOperator::IntrinsicOperator::Add) 994 TODO(currentLocation, 995 "Reduction of some intrinsic operators is not supported"); 996 for (const auto &ompObject : objectList.v) { 997 if (const auto *name{ 998 Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { 999 if (const auto *symbol{name->symbol}) { 1000 mlir::Value symVal = converter.getSymbolAddress(*symbol); 1001 mlir::Type redType = 1002 symVal.getType().cast<fir::ReferenceType>().getEleTy(); 1003 reductionVars.push_back(symVal); 1004 if (redType.isIntOrIndex()) { 1005 decl = createReductionDecl( 1006 firOpBuilder, getReductionName(intrinsicOp, redType), 1007 redType, currentLocation); 1008 } else { 1009 TODO(currentLocation, 1010 "Reduction of some types is not supported"); 1011 } 1012 reductionDeclSymbols.push_back(SymbolRefAttr::get( 1013 firOpBuilder.getContext(), decl.sym_name())); 1014 } 1015 } 1016 } 1017 } else { 1018 TODO(currentLocation, 1019 "Reduction of intrinsic procedures is not supported"); 1020 } 1021 } 1022 } 1023 1024 // The types of lower bound, upper bound, and step are converted into the 1025 // type of the loop variable if necessary. 1026 mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); 1027 for (unsigned it = 0; it < (unsigned)lowerBound.size(); it++) { 1028 lowerBound[it] = firOpBuilder.createConvert(currentLocation, loopVarType, 1029 lowerBound[it]); 1030 upperBound[it] = firOpBuilder.createConvert(currentLocation, loopVarType, 1031 upperBound[it]); 1032 step[it] = 1033 firOpBuilder.createConvert(currentLocation, loopVarType, step[it]); 1034 } 1035 1036 // 2.9.3.1 SIMD construct 1037 // TODO: Support all the clauses 1038 if (llvm::omp::OMPD_simd == ompDirective) { 1039 TypeRange resultType; 1040 auto SimdLoopOp = firOpBuilder.create<mlir::omp::SimdLoopOp>( 1041 currentLocation, resultType, lowerBound, upperBound, step, 1042 ifClauseOperand, /*inclusive=*/firOpBuilder.getUnitAttr()); 1043 createBodyOfOp<omp::SimdLoopOp>(SimdLoopOp, converter, currentLocation, 1044 eval, &loopOpClauseList, iv); 1045 return; 1046 } 1047 1048 // FIXME: Add support for following clauses: 1049 // 1. linear 1050 // 2. order 1051 auto wsLoopOp = firOpBuilder.create<mlir::omp::WsLoopOp>( 1052 currentLocation, lowerBound, upperBound, step, linearVars, linearStepVars, 1053 reductionVars, 1054 reductionDeclSymbols.empty() 1055 ? nullptr 1056 : mlir::ArrayAttr::get(firOpBuilder.getContext(), 1057 reductionDeclSymbols), 1058 scheduleClauseOperand.dyn_cast_or_null<omp::ClauseScheduleKindAttr>(), 1059 scheduleChunkClauseOperand, /*schedule_modifiers=*/nullptr, 1060 /*simd_modifier=*/nullptr, 1061 noWaitClauseOperand.dyn_cast_or_null<UnitAttr>(), 1062 orderedClauseOperand.dyn_cast_or_null<IntegerAttr>(), 1063 orderClauseOperand.dyn_cast_or_null<omp::ClauseOrderKindAttr>(), 1064 /*inclusive=*/firOpBuilder.getUnitAttr()); 1065 1066 // Handle attribute based clauses. 1067 for (const Fortran::parser::OmpClause &clause : loopOpClauseList.v) { 1068 if (const auto &orderedClause = 1069 std::get_if<Fortran::parser::OmpClause::Ordered>(&clause.u)) { 1070 if (orderedClause->v.has_value()) { 1071 const auto *expr = Fortran::semantics::GetExpr(orderedClause->v); 1072 const std::optional<std::int64_t> orderedClauseValue = 1073 Fortran::evaluate::ToInt64(*expr); 1074 wsLoopOp.ordered_valAttr( 1075 firOpBuilder.getI64IntegerAttr(*orderedClauseValue)); 1076 } else { 1077 wsLoopOp.ordered_valAttr(firOpBuilder.getI64IntegerAttr(0)); 1078 } 1079 } else if (const auto &scheduleClause = 1080 std::get_if<Fortran::parser::OmpClause::Schedule>( 1081 &clause.u)) { 1082 mlir::MLIRContext *context = firOpBuilder.getContext(); 1083 const auto &scheduleType = scheduleClause->v; 1084 const auto &scheduleKind = 1085 std::get<Fortran::parser::OmpScheduleClause::ScheduleType>( 1086 scheduleType.t); 1087 switch (scheduleKind) { 1088 case Fortran::parser::OmpScheduleClause::ScheduleType::Static: 1089 wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get( 1090 context, omp::ClauseScheduleKind::Static)); 1091 break; 1092 case Fortran::parser::OmpScheduleClause::ScheduleType::Dynamic: 1093 wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get( 1094 context, omp::ClauseScheduleKind::Dynamic)); 1095 break; 1096 case Fortran::parser::OmpScheduleClause::ScheduleType::Guided: 1097 wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get( 1098 context, omp::ClauseScheduleKind::Guided)); 1099 break; 1100 case Fortran::parser::OmpScheduleClause::ScheduleType::Auto: 1101 wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get( 1102 context, omp::ClauseScheduleKind::Auto)); 1103 break; 1104 case Fortran::parser::OmpScheduleClause::ScheduleType::Runtime: 1105 wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get( 1106 context, omp::ClauseScheduleKind::Runtime)); 1107 break; 1108 } 1109 mlir::omp::ScheduleModifier scheduleModifier = 1110 getScheduleModifier(scheduleClause->v); 1111 if (scheduleModifier != mlir::omp::ScheduleModifier::none) 1112 wsLoopOp.schedule_modifierAttr( 1113 omp::ScheduleModifierAttr::get(context, scheduleModifier)); 1114 if (getSIMDModifier(scheduleClause->v) != 1115 mlir::omp::ScheduleModifier::none) 1116 wsLoopOp.simd_modifierAttr(firOpBuilder.getUnitAttr()); 1117 } 1118 } 1119 // In FORTRAN `nowait` clause occur at the end of `omp do` directive. 1120 // i.e 1121 // !$omp do 1122 // <...> 1123 // !$omp end do nowait 1124 if (const auto &endClauseList = 1125 std::get<std::optional<Fortran::parser::OmpEndLoopDirective>>( 1126 loopConstruct.t)) { 1127 const auto &clauseList = 1128 std::get<Fortran::parser::OmpClauseList>((*endClauseList).t); 1129 for (const Fortran::parser::OmpClause &clause : clauseList.v) 1130 if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u)) 1131 wsLoopOp.nowaitAttr(firOpBuilder.getUnitAttr()); 1132 } 1133 1134 createBodyOfOp<omp::WsLoopOp>(wsLoopOp, converter, currentLocation, eval, 1135 &loopOpClauseList, iv); 1136 } 1137 1138 static void 1139 genOMP(Fortran::lower::AbstractConverter &converter, 1140 Fortran::lower::pft::Evaluation &eval, 1141 const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) { 1142 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 1143 mlir::Location currentLocation = converter.getCurrentLocation(); 1144 std::string name; 1145 const Fortran::parser::OmpCriticalDirective &cd = 1146 std::get<Fortran::parser::OmpCriticalDirective>(criticalConstruct.t); 1147 if (std::get<std::optional<Fortran::parser::Name>>(cd.t).has_value()) { 1148 name = 1149 std::get<std::optional<Fortran::parser::Name>>(cd.t).value().ToString(); 1150 } 1151 1152 uint64_t hint = 0; 1153 const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t); 1154 for (const Fortran::parser::OmpClause &clause : clauseList.v) 1155 if (auto hintClause = 1156 std::get_if<Fortran::parser::OmpClause::Hint>(&clause.u)) { 1157 const auto *expr = Fortran::semantics::GetExpr(hintClause->v); 1158 hint = *Fortran::evaluate::ToInt64(*expr); 1159 break; 1160 } 1161 1162 mlir::omp::CriticalOp criticalOp = [&]() { 1163 if (name.empty()) { 1164 return firOpBuilder.create<mlir::omp::CriticalOp>(currentLocation, 1165 FlatSymbolRefAttr()); 1166 } else { 1167 mlir::ModuleOp module = firOpBuilder.getModule(); 1168 mlir::OpBuilder modBuilder(module.getBodyRegion()); 1169 auto global = module.lookupSymbol<mlir::omp::CriticalDeclareOp>(name); 1170 if (!global) 1171 global = modBuilder.create<mlir::omp::CriticalDeclareOp>( 1172 currentLocation, name, hint); 1173 return firOpBuilder.create<mlir::omp::CriticalOp>( 1174 currentLocation, mlir::FlatSymbolRefAttr::get( 1175 firOpBuilder.getContext(), global.sym_name())); 1176 } 1177 }(); 1178 createBodyOfOp<omp::CriticalOp>(criticalOp, converter, currentLocation, eval); 1179 } 1180 1181 static void 1182 genOMP(Fortran::lower::AbstractConverter &converter, 1183 Fortran::lower::pft::Evaluation &eval, 1184 const Fortran::parser::OpenMPSectionConstruct §ionConstruct) { 1185 1186 auto &firOpBuilder = converter.getFirOpBuilder(); 1187 auto currentLocation = converter.getCurrentLocation(); 1188 mlir::omp::SectionOp sectionOp = 1189 firOpBuilder.create<mlir::omp::SectionOp>(currentLocation); 1190 createBodyOfOp<omp::SectionOp>(sectionOp, converter, currentLocation, eval); 1191 } 1192 1193 // TODO: Add support for reduction 1194 static void 1195 genOMP(Fortran::lower::AbstractConverter &converter, 1196 Fortran::lower::pft::Evaluation &eval, 1197 const Fortran::parser::OpenMPSectionsConstruct §ionsConstruct) { 1198 auto &firOpBuilder = converter.getFirOpBuilder(); 1199 auto currentLocation = converter.getCurrentLocation(); 1200 SmallVector<Value> reductionVars, allocateOperands, allocatorOperands; 1201 mlir::UnitAttr noWaitClauseOperand; 1202 const auto §ionsClauseList = std::get<Fortran::parser::OmpClauseList>( 1203 std::get<Fortran::parser::OmpBeginSectionsDirective>(sectionsConstruct.t) 1204 .t); 1205 for (const Fortran::parser::OmpClause &clause : sectionsClauseList.v) { 1206 1207 // Reduction Clause 1208 if (std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) { 1209 TODO(currentLocation, "OMPC_Reduction"); 1210 1211 // Allocate clause 1212 } else if (const auto &allocateClause = 1213 std::get_if<Fortran::parser::OmpClause::Allocate>( 1214 &clause.u)) { 1215 genAllocateClause(converter, allocateClause->v, allocatorOperands, 1216 allocateOperands); 1217 } 1218 } 1219 const auto &endSectionsClauseList = 1220 std::get<Fortran::parser::OmpEndSectionsDirective>(sectionsConstruct.t); 1221 const auto &clauseList = 1222 std::get<Fortran::parser::OmpClauseList>(endSectionsClauseList.t); 1223 for (const auto &clause : clauseList.v) { 1224 // Nowait clause 1225 if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u)) { 1226 noWaitClauseOperand = firOpBuilder.getUnitAttr(); 1227 } 1228 } 1229 1230 llvm::omp::Directive dir = 1231 std::get<Fortran::parser::OmpSectionsDirective>( 1232 std::get<Fortran::parser::OmpBeginSectionsDirective>( 1233 sectionsConstruct.t) 1234 .t) 1235 .v; 1236 1237 // Parallel Sections Construct 1238 if (dir == llvm::omp::Directive::OMPD_parallel_sections) { 1239 createCombinedParallelOp<Fortran::parser::OmpBeginSectionsDirective>( 1240 converter, eval, 1241 std::get<Fortran::parser::OmpBeginSectionsDirective>( 1242 sectionsConstruct.t)); 1243 auto sectionsOp = firOpBuilder.create<mlir::omp::SectionsOp>( 1244 currentLocation, /*reduction_vars*/ ValueRange(), 1245 /*reductions=*/nullptr, allocateOperands, allocatorOperands, 1246 /*nowait=*/nullptr); 1247 createBodyOfOp(sectionsOp, converter, currentLocation, eval); 1248 1249 // Sections Construct 1250 } else if (dir == llvm::omp::Directive::OMPD_sections) { 1251 auto sectionsOp = firOpBuilder.create<mlir::omp::SectionsOp>( 1252 currentLocation, reductionVars, /*reductions = */ nullptr, 1253 allocateOperands, allocatorOperands, noWaitClauseOperand); 1254 createBodyOfOp<omp::SectionsOp>(sectionsOp, converter, currentLocation, 1255 eval); 1256 } 1257 } 1258 1259 static void genOmpAtomicHintAndMemoryOrderClauses( 1260 Fortran::lower::AbstractConverter &converter, 1261 const Fortran::parser::OmpAtomicClauseList &clauseList, 1262 mlir::IntegerAttr &hint, 1263 mlir::omp::ClauseMemoryOrderKindAttr &memory_order) { 1264 auto &firOpBuilder = converter.getFirOpBuilder(); 1265 for (const auto &clause : clauseList.v) { 1266 if (auto ompClause = std::get_if<Fortran::parser::OmpClause>(&clause.u)) { 1267 if (auto hintClause = 1268 std::get_if<Fortran::parser::OmpClause::Hint>(&ompClause->u)) { 1269 const auto *expr = Fortran::semantics::GetExpr(hintClause->v); 1270 uint64_t hintExprValue = *Fortran::evaluate::ToInt64(*expr); 1271 hint = firOpBuilder.getI64IntegerAttr(hintExprValue); 1272 } 1273 } else if (auto ompMemoryOrderClause = 1274 std::get_if<Fortran::parser::OmpMemoryOrderClause>( 1275 &clause.u)) { 1276 if (std::get_if<Fortran::parser::OmpClause::Acquire>( 1277 &ompMemoryOrderClause->v.u)) { 1278 memory_order = mlir::omp::ClauseMemoryOrderKindAttr::get( 1279 firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Acquire); 1280 } else if (std::get_if<Fortran::parser::OmpClause::Relaxed>( 1281 &ompMemoryOrderClause->v.u)) { 1282 memory_order = mlir::omp::ClauseMemoryOrderKindAttr::get( 1283 firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Relaxed); 1284 } else if (std::get_if<Fortran::parser::OmpClause::SeqCst>( 1285 &ompMemoryOrderClause->v.u)) { 1286 memory_order = mlir::omp::ClauseMemoryOrderKindAttr::get( 1287 firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Seq_cst); 1288 } else if (std::get_if<Fortran::parser::OmpClause::Release>( 1289 &ompMemoryOrderClause->v.u)) { 1290 memory_order = mlir::omp::ClauseMemoryOrderKindAttr::get( 1291 firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Release); 1292 } 1293 } 1294 } 1295 } 1296 1297 static void genOmpAtomicUpdateStatement( 1298 Fortran::lower::AbstractConverter &converter, 1299 Fortran::lower::pft::Evaluation &eval, 1300 const Fortran::parser::Variable &assignmentStmtVariable, 1301 const Fortran::parser::Expr &assignmentStmtExpr, 1302 const Fortran::parser::OmpAtomicClauseList *leftHandClauseList, 1303 const Fortran::parser::OmpAtomicClauseList *rightHandClauseList) { 1304 // Generate `omp.atomic.update` operation for atomic assignment statements 1305 auto &firOpBuilder = converter.getFirOpBuilder(); 1306 auto currentLocation = converter.getCurrentLocation(); 1307 Fortran::lower::StatementContext stmtCtx; 1308 1309 mlir::Value address = fir::getBase(converter.genExprAddr( 1310 *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx)); 1311 // If no hint clause is specified, the effect is as if 1312 // hint(omp_sync_hint_none) had been specified. 1313 mlir::IntegerAttr hint = nullptr; 1314 mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr; 1315 if (leftHandClauseList) 1316 genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList, hint, 1317 memory_order); 1318 if (rightHandClauseList) 1319 genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, hint, 1320 memory_order); 1321 auto atomicUpdateOp = firOpBuilder.create<mlir::omp::AtomicUpdateOp>( 1322 currentLocation, address, hint, memory_order); 1323 1324 //// Generate body of Atomic Update operation 1325 // If an argument for the region is provided then create the block with that 1326 // argument. Also update the symbol's address with the argument mlir value. 1327 mlir::Type varType = 1328 fir::getBase( 1329 converter.genExprValue( 1330 *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx)) 1331 .getType(); 1332 SmallVector<Type> varTys = {varType}; 1333 SmallVector<Location> locs = {currentLocation}; 1334 firOpBuilder.createBlock(&atomicUpdateOp.getRegion(), {}, varTys, locs); 1335 mlir::Value val = 1336 fir::getBase(atomicUpdateOp.getRegion().front().getArgument(0)); 1337 auto varDesignator = 1338 std::get_if<Fortran::common::Indirection<Fortran::parser::Designator>>( 1339 &assignmentStmtVariable.u); 1340 assert(varDesignator && "Variable designator for atomic update assignment " 1341 "statement does not exist"); 1342 const auto *name = getDesignatorNameIfDataRef(varDesignator->value()); 1343 assert(name && name->symbol && 1344 "No symbol attached to atomic update variable"); 1345 converter.bindSymbol(*name->symbol, val); 1346 // Set the insert for the terminator operation to go at the end of the 1347 // block. 1348 mlir::Block &block = atomicUpdateOp.getRegion().back(); 1349 firOpBuilder.setInsertionPointToEnd(&block); 1350 1351 mlir::Value result = fir::getBase(converter.genExprValue( 1352 *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx)); 1353 // Insert the terminator: YieldOp. 1354 firOpBuilder.create<mlir::omp::YieldOp>(currentLocation, result); 1355 // Reset the insert point to before the terminator. 1356 firOpBuilder.setInsertionPointToStart(&block); 1357 } 1358 1359 static void 1360 genOmpAtomicWrite(Fortran::lower::AbstractConverter &converter, 1361 Fortran::lower::pft::Evaluation &eval, 1362 const Fortran::parser::OmpAtomicWrite &atomicWrite) { 1363 auto &firOpBuilder = converter.getFirOpBuilder(); 1364 auto currentLocation = converter.getCurrentLocation(); 1365 // Get the value and address of atomic write operands. 1366 const Fortran::parser::OmpAtomicClauseList &rightHandClauseList = 1367 std::get<2>(atomicWrite.t); 1368 const Fortran::parser::OmpAtomicClauseList &leftHandClauseList = 1369 std::get<0>(atomicWrite.t); 1370 const auto &assignmentStmtExpr = 1371 std::get<Fortran::parser::Expr>(std::get<3>(atomicWrite.t).statement.t); 1372 const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>( 1373 std::get<3>(atomicWrite.t).statement.t); 1374 Fortran::lower::StatementContext stmtCtx; 1375 mlir::Value value = fir::getBase(converter.genExprValue( 1376 *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx)); 1377 mlir::Value address = fir::getBase(converter.genExprAddr( 1378 *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx)); 1379 // If no hint clause is specified, the effect is as if 1380 // hint(omp_sync_hint_none) had been specified. 1381 mlir::IntegerAttr hint = nullptr; 1382 mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr; 1383 genOmpAtomicHintAndMemoryOrderClauses(converter, leftHandClauseList, hint, 1384 memory_order); 1385 genOmpAtomicHintAndMemoryOrderClauses(converter, rightHandClauseList, hint, 1386 memory_order); 1387 firOpBuilder.create<mlir::omp::AtomicWriteOp>(currentLocation, address, value, 1388 hint, memory_order); 1389 } 1390 1391 static void genOmpAtomicRead(Fortran::lower::AbstractConverter &converter, 1392 Fortran::lower::pft::Evaluation &eval, 1393 const Fortran::parser::OmpAtomicRead &atomicRead) { 1394 auto &firOpBuilder = converter.getFirOpBuilder(); 1395 auto currentLocation = converter.getCurrentLocation(); 1396 // Get the address of atomic read operands. 1397 const Fortran::parser::OmpAtomicClauseList &rightHandClauseList = 1398 std::get<2>(atomicRead.t); 1399 const Fortran::parser::OmpAtomicClauseList &leftHandClauseList = 1400 std::get<0>(atomicRead.t); 1401 const auto &assignmentStmtExpr = 1402 std::get<Fortran::parser::Expr>(std::get<3>(atomicRead.t).statement.t); 1403 const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>( 1404 std::get<3>(atomicRead.t).statement.t); 1405 Fortran::lower::StatementContext stmtCtx; 1406 mlir::Value from_address = fir::getBase(converter.genExprAddr( 1407 *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx)); 1408 mlir::Value to_address = fir::getBase(converter.genExprAddr( 1409 *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx)); 1410 // If no hint clause is specified, the effect is as if 1411 // hint(omp_sync_hint_none) had been specified. 1412 mlir::IntegerAttr hint = nullptr; 1413 mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr; 1414 genOmpAtomicHintAndMemoryOrderClauses(converter, leftHandClauseList, hint, 1415 memory_order); 1416 genOmpAtomicHintAndMemoryOrderClauses(converter, rightHandClauseList, hint, 1417 memory_order); 1418 firOpBuilder.create<mlir::omp::AtomicReadOp>(currentLocation, from_address, 1419 to_address, hint, memory_order); 1420 } 1421 1422 static void 1423 genOmpAtomicUpdate(Fortran::lower::AbstractConverter &converter, 1424 Fortran::lower::pft::Evaluation &eval, 1425 const Fortran::parser::OmpAtomicUpdate &atomicUpdate) { 1426 const Fortran::parser::OmpAtomicClauseList &rightHandClauseList = 1427 std::get<2>(atomicUpdate.t); 1428 const Fortran::parser::OmpAtomicClauseList &leftHandClauseList = 1429 std::get<0>(atomicUpdate.t); 1430 const auto &assignmentStmtExpr = 1431 std::get<Fortran::parser::Expr>(std::get<3>(atomicUpdate.t).statement.t); 1432 const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>( 1433 std::get<3>(atomicUpdate.t).statement.t); 1434 1435 genOmpAtomicUpdateStatement(converter, eval, assignmentStmtVariable, 1436 assignmentStmtExpr, &leftHandClauseList, 1437 &rightHandClauseList); 1438 } 1439 1440 static void genOmpAtomic(Fortran::lower::AbstractConverter &converter, 1441 Fortran::lower::pft::Evaluation &eval, 1442 const Fortran::parser::OmpAtomic &atomicConstruct) { 1443 const Fortran::parser::OmpAtomicClauseList &atomicClauseList = 1444 std::get<Fortran::parser::OmpAtomicClauseList>(atomicConstruct.t); 1445 const auto &assignmentStmtExpr = std::get<Fortran::parser::Expr>( 1446 std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>( 1447 atomicConstruct.t) 1448 .statement.t); 1449 const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>( 1450 std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>( 1451 atomicConstruct.t) 1452 .statement.t); 1453 // If atomic-clause is not present on the construct, the behaviour is as if 1454 // the update clause is specified 1455 genOmpAtomicUpdateStatement(converter, eval, assignmentStmtVariable, 1456 assignmentStmtExpr, &atomicClauseList, nullptr); 1457 } 1458 1459 static void 1460 genOMP(Fortran::lower::AbstractConverter &converter, 1461 Fortran::lower::pft::Evaluation &eval, 1462 const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) { 1463 std::visit(Fortran::common::visitors{ 1464 [&](const Fortran::parser::OmpAtomicRead &atomicRead) { 1465 genOmpAtomicRead(converter, eval, atomicRead); 1466 }, 1467 [&](const Fortran::parser::OmpAtomicWrite &atomicWrite) { 1468 genOmpAtomicWrite(converter, eval, atomicWrite); 1469 }, 1470 [&](const Fortran::parser::OmpAtomic &atomicConstruct) { 1471 genOmpAtomic(converter, eval, atomicConstruct); 1472 }, 1473 [&](const Fortran::parser::OmpAtomicUpdate &atomicUpdate) { 1474 genOmpAtomicUpdate(converter, eval, atomicUpdate); 1475 }, 1476 [&](const auto &) { 1477 TODO(converter.getCurrentLocation(), "Atomic capture"); 1478 }, 1479 }, 1480 atomicConstruct.u); 1481 } 1482 1483 void Fortran::lower::genOpenMPConstruct( 1484 Fortran::lower::AbstractConverter &converter, 1485 Fortran::lower::pft::Evaluation &eval, 1486 const Fortran::parser::OpenMPConstruct &ompConstruct) { 1487 1488 std::visit( 1489 common::visitors{ 1490 [&](const Fortran::parser::OpenMPStandaloneConstruct 1491 &standaloneConstruct) { 1492 genOMP(converter, eval, standaloneConstruct); 1493 }, 1494 [&](const Fortran::parser::OpenMPSectionsConstruct 1495 §ionsConstruct) { 1496 genOMP(converter, eval, sectionsConstruct); 1497 }, 1498 [&](const Fortran::parser::OpenMPSectionConstruct §ionConstruct) { 1499 genOMP(converter, eval, sectionConstruct); 1500 }, 1501 [&](const Fortran::parser::OpenMPLoopConstruct &loopConstruct) { 1502 genOMP(converter, eval, loopConstruct); 1503 }, 1504 [&](const Fortran::parser::OpenMPDeclarativeAllocate 1505 &execAllocConstruct) { 1506 TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate"); 1507 }, 1508 [&](const Fortran::parser::OpenMPExecutableAllocate 1509 &execAllocConstruct) { 1510 TODO(converter.getCurrentLocation(), "OpenMPExecutableAllocate"); 1511 }, 1512 [&](const Fortran::parser::OpenMPBlockConstruct &blockConstruct) { 1513 genOMP(converter, eval, blockConstruct); 1514 }, 1515 [&](const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) { 1516 genOMP(converter, eval, atomicConstruct); 1517 }, 1518 [&](const Fortran::parser::OpenMPCriticalConstruct 1519 &criticalConstruct) { 1520 genOMP(converter, eval, criticalConstruct); 1521 }, 1522 }, 1523 ompConstruct.u); 1524 } 1525 1526 void Fortran::lower::genThreadprivateOp( 1527 Fortran::lower::AbstractConverter &converter, 1528 const Fortran::lower::pft::Variable &var) { 1529 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 1530 mlir::Location currentLocation = converter.getCurrentLocation(); 1531 1532 const Fortran::semantics::Symbol &sym = var.getSymbol(); 1533 mlir::Value symThreadprivateValue; 1534 if (const Fortran::semantics::Symbol *common = 1535 Fortran::semantics::FindCommonBlockContaining(sym.GetUltimate())) { 1536 mlir::Value commonValue = converter.getSymbolAddress(*common); 1537 if (mlir::isa<mlir::omp::ThreadprivateOp>(commonValue.getDefiningOp())) { 1538 // Generate ThreadprivateOp for a common block instead of its members and 1539 // only do it once for a common block. 1540 return; 1541 } 1542 // Generate ThreadprivateOp and rebind the common block. 1543 mlir::Value commonThreadprivateValue = 1544 firOpBuilder.create<mlir::omp::ThreadprivateOp>( 1545 currentLocation, commonValue.getType(), commonValue); 1546 converter.bindSymbol(*common, commonThreadprivateValue); 1547 // Generate the threadprivate value for the common block member. 1548 symThreadprivateValue = 1549 genCommonBlockMember(converter, sym, commonThreadprivateValue); 1550 } else { 1551 mlir::Value symValue = converter.getSymbolAddress(sym); 1552 symThreadprivateValue = firOpBuilder.create<mlir::omp::ThreadprivateOp>( 1553 currentLocation, symValue.getType(), symValue); 1554 } 1555 1556 fir::ExtendedValue sexv = converter.getSymbolExtendedValue(sym); 1557 fir::ExtendedValue symThreadprivateExv = 1558 getExtendedValue(sexv, symThreadprivateValue); 1559 converter.bindSymbol(sym, symThreadprivateExv); 1560 } 1561 1562 void Fortran::lower::genOpenMPDeclarativeConstruct( 1563 Fortran::lower::AbstractConverter &converter, 1564 Fortran::lower::pft::Evaluation &eval, 1565 const Fortran::parser::OpenMPDeclarativeConstruct &ompDeclConstruct) { 1566 1567 std::visit( 1568 common::visitors{ 1569 [&](const Fortran::parser::OpenMPDeclarativeAllocate 1570 &declarativeAllocate) { 1571 TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate"); 1572 }, 1573 [&](const Fortran::parser::OpenMPDeclareReductionConstruct 1574 &declareReductionConstruct) { 1575 TODO(converter.getCurrentLocation(), 1576 "OpenMPDeclareReductionConstruct"); 1577 }, 1578 [&](const Fortran::parser::OpenMPDeclareSimdConstruct 1579 &declareSimdConstruct) { 1580 TODO(converter.getCurrentLocation(), "OpenMPDeclareSimdConstruct"); 1581 }, 1582 [&](const Fortran::parser::OpenMPDeclareTargetConstruct 1583 &declareTargetConstruct) { 1584 TODO(converter.getCurrentLocation(), 1585 "OpenMPDeclareTargetConstruct"); 1586 }, 1587 [&](const Fortran::parser::OpenMPThreadprivate &threadprivate) { 1588 // The directive is lowered when instantiating the variable to 1589 // support the case of threadprivate variable declared in module. 1590 }, 1591 }, 1592 ompDeclConstruct.u); 1593 } 1594 1595 // Generate an OpenMP reduction operation. This implementation finds the chain : 1596 // load reduction var -> reduction_operation -> store reduction var and replaces 1597 // it with the reduction operation. 1598 // TODO: Currently assumes it is an integer addition reduction. Generalize this 1599 // for various reduction operation types. 1600 // TODO: Generate the reduction operation during lowering instead of creating 1601 // and removing operations since this is not a robust approach. Also, removing 1602 // ops in the builder (instead of a rewriter) is probably not the best approach. 1603 void Fortran::lower::genOpenMPReduction( 1604 Fortran::lower::AbstractConverter &converter, 1605 const Fortran::parser::OmpClauseList &clauseList) { 1606 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 1607 1608 for (const auto &clause : clauseList.v) { 1609 if (const auto &reductionClause = 1610 std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) { 1611 const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>( 1612 reductionClause->v.t)}; 1613 const auto &objectList{ 1614 std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)}; 1615 if (auto reductionOp = 1616 std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) { 1617 const auto &intrinsicOp{ 1618 std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>( 1619 reductionOp->u)}; 1620 if (intrinsicOp != 1621 Fortran::parser::DefinedOperator::IntrinsicOperator::Add) 1622 continue; 1623 for (const auto &ompObject : objectList.v) { 1624 if (const auto *name{ 1625 Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { 1626 if (const auto *symbol{name->symbol}) { 1627 mlir::Value symVal = converter.getSymbolAddress(*symbol); 1628 mlir::Type redType = 1629 symVal.getType().cast<fir::ReferenceType>().getEleTy(); 1630 if (!redType.isIntOrIndex()) 1631 continue; 1632 for (mlir::OpOperand &use1 : symVal.getUses()) { 1633 if (auto load = mlir::dyn_cast<fir::LoadOp>(use1.getOwner())) { 1634 mlir::Value loadVal = load.getRes(); 1635 for (mlir::OpOperand &use2 : loadVal.getUses()) { 1636 if (auto add = mlir::dyn_cast<mlir::arith::AddIOp>( 1637 use2.getOwner())) { 1638 mlir::Value addRes = add.getResult(); 1639 for (mlir::OpOperand &use3 : addRes.getUses()) { 1640 if (auto store = 1641 mlir::dyn_cast<fir::StoreOp>(use3.getOwner())) { 1642 if (store.getMemref() == symVal) { 1643 // Chain found! Now replace load->reduction->store 1644 // with the OpenMP reduction operation. 1645 mlir::OpBuilder::InsertPoint insertPtDel = 1646 firOpBuilder.saveInsertionPoint(); 1647 firOpBuilder.setInsertionPoint(add); 1648 if (add.getLhs() == loadVal) { 1649 firOpBuilder.create<mlir::omp::ReductionOp>( 1650 add.getLoc(), add.getRhs(), symVal); 1651 } else { 1652 firOpBuilder.create<mlir::omp::ReductionOp>( 1653 add.getLoc(), add.getLhs(), symVal); 1654 } 1655 store.erase(); 1656 add.erase(); 1657 load.erase(); 1658 firOpBuilder.restoreInsertionPoint(insertPtDel); 1659 } 1660 } 1661 } 1662 } 1663 } 1664 } 1665 } 1666 } 1667 } 1668 } 1669 } 1670 } 1671 } 1672 } 1673