1 //===-- OpenMP.cpp -- Open MP directive lowering --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Lower/OpenMP.h" 14 #include "flang/Common/idioms.h" 15 #include "flang/Lower/Bridge.h" 16 #include "flang/Lower/ConvertExpr.h" 17 #include "flang/Lower/PFTBuilder.h" 18 #include "flang/Lower/StatementContext.h" 19 #include "flang/Optimizer/Builder/BoxValue.h" 20 #include "flang/Optimizer/Builder/FIRBuilder.h" 21 #include "flang/Optimizer/Builder/Todo.h" 22 #include "flang/Parser/parse-tree.h" 23 #include "flang/Semantics/tools.h" 24 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 25 #include "mlir/Dialect/SCF/IR/SCF.h" 26 #include "llvm/Frontend/OpenMP/OMPConstants.h" 27 28 using namespace mlir; 29 30 int64_t Fortran::lower::getCollapseValue( 31 const Fortran::parser::OmpClauseList &clauseList) { 32 for (const auto &clause : clauseList.v) { 33 if (const auto &collapseClause = 34 std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u)) { 35 const auto *expr = Fortran::semantics::GetExpr(collapseClause->v); 36 return Fortran::evaluate::ToInt64(*expr).value(); 37 } 38 } 39 return 1; 40 } 41 42 static const Fortran::parser::Name * 43 getDesignatorNameIfDataRef(const Fortran::parser::Designator &designator) { 44 const auto *dataRef = std::get_if<Fortran::parser::DataRef>(&designator.u); 45 return dataRef ? std::get_if<Fortran::parser::Name>(&dataRef->u) : nullptr; 46 } 47 48 static Fortran::semantics::Symbol * 49 getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) { 50 Fortran::semantics::Symbol *sym = nullptr; 51 std::visit(Fortran::common::visitors{ 52 [&](const Fortran::parser::Designator &designator) { 53 if (const Fortran::parser::Name *name = 54 getDesignatorNameIfDataRef(designator)) { 55 sym = name->symbol; 56 } 57 }, 58 [&](const Fortran::parser::Name &name) { sym = name.symbol; }}, 59 ompObject.u); 60 return sym; 61 } 62 63 template <typename T> 64 static void createPrivateVarSyms(Fortran::lower::AbstractConverter &converter, 65 const T *clause, 66 Block *lastPrivBlock = nullptr) { 67 const Fortran::parser::OmpObjectList &ompObjectList = clause->v; 68 for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) { 69 Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); 70 // Privatization for symbols which are pre-determined (like loop index 71 // variables) happen separately, for everything else privatize here. 72 if (sym->test(Fortran::semantics::Symbol::Flag::OmpPreDetermined)) 73 continue; 74 bool success = converter.createHostAssociateVarClone(*sym); 75 (void)success; 76 assert(success && "Privatization failed due to existing binding"); 77 if constexpr (std::is_same_v<T, Fortran::parser::OmpClause::Firstprivate>) { 78 converter.copyHostAssociateVar(*sym); 79 } else if constexpr (std::is_same_v< 80 T, Fortran::parser::OmpClause::Lastprivate>) { 81 converter.copyHostAssociateVar(*sym, lastPrivBlock); 82 } 83 } 84 } 85 86 template <typename Op> 87 static bool privatizeVars(Op &op, Fortran::lower::AbstractConverter &converter, 88 const Fortran::parser::OmpClauseList &opClauseList, 89 Fortran::lower::pft::Evaluation &eval) { 90 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 91 auto insPt = firOpBuilder.saveInsertionPoint(); 92 firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock()); 93 bool hasFirstPrivateOp = false; 94 bool hasLastPrivateOp = false; 95 // The symbols in private/firstprivate clause or region under 96 // default(private/firstprivate) clause. 97 // TODO: Current implementation of default clause is very fragile. Implement 98 // it instead if a pre-FIR pass 99 llvm::SetVector<const Fortran::semantics::Symbol *> symbols; 100 llvm::SetVector<const Fortran::semantics::Symbol *> sharedSymbols; 101 std::map<Fortran::semantics::SourceName, const Fortran::semantics::Symbol *> 102 uniqueSymbolMap; 103 auto collectOmpObjectListSymbol = 104 [&](const Fortran::parser::OmpObjectList &ompObjectList, 105 llvm::SetVector<const Fortran::semantics::Symbol *> &symbolSet) { 106 for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) { 107 Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); 108 symbolSet.insert(sym); 109 } 110 }; 111 112 auto removeDuplicateSymbols = 113 [&](llvm::SetVector<const Fortran::semantics::Symbol *> &symbols) { 114 // default clause in nested directives can create two symbols belonging 115 // to the same entity resulting in double privatization. Reduce the 116 // count of such symbols to 1. 117 for (auto sym : symbols) 118 if (auto it{uniqueSymbolMap.find(sym->name())}; 119 it == uniqueSymbolMap.end()) 120 uniqueSymbolMap.insert( 121 std::pair<Fortran::semantics::SourceName, 122 const Fortran::semantics::Symbol *>(sym->name(), 123 sym)); 124 }; 125 // We need just one ICmpOp for multiple LastPrivate clauses. 126 mlir::arith::CmpIOp cmpOp; 127 // Collect the privatized symbols for private/firstprivate/default clause. 128 for (const Fortran::parser::OmpClause &clause : opClauseList.v) { 129 if (const auto &privateClause = 130 std::get_if<Fortran::parser::OmpClause::Private>(&clause.u)) { 131 collectOmpObjectListSymbol(privateClause->v, symbols); 132 removeDuplicateSymbols(symbols); 133 } else if (const auto &firstPrivateClause = 134 std::get_if<Fortran::parser::OmpClause::Firstprivate>( 135 &clause.u)) { 136 collectOmpObjectListSymbol(firstPrivateClause->v, symbols); 137 removeDuplicateSymbols(symbols); 138 hasFirstPrivateOp = true; 139 } else if (const auto &sharedClause = 140 std::get_if<Fortran::parser::OmpClause::Shared>(&clause.u)) { 141 // `shared(...)` does not define a privatization flag. In case of nested 142 // block constructs, presence of `default(firstprivate)` or 143 // `default(private)` in the inner block construct may lead to privatizing 144 // a shared variable in the outer block construct. To prevent that, 145 // explicitly collect the shared symbols and remove them from the list of 146 // symbols to be privatized. 147 collectOmpObjectListSymbol(sharedClause->v, sharedSymbols); 148 } else if (const auto &lastPrivateClause = 149 std::get_if<Fortran::parser::OmpClause::Lastprivate>( 150 &clause.u)) { 151 // TODO: Add lastprivate support for sections construct, simd construct 152 if (std::is_same_v<Op, omp::WsLoopOp>) { 153 omp::WsLoopOp *wsLoopOp = dyn_cast<omp::WsLoopOp>(&op); 154 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 155 auto insPt = firOpBuilder.saveInsertionPoint(); 156 157 // Our goal here is to introduce the following control flow 158 // just before exiting the worksharing loop. 159 // Say our wsloop is as follows: 160 // 161 // omp.wsloop { 162 // ... 163 // store 164 // omp.yield 165 // } 166 // 167 // We want to convert it to the following: 168 // 169 // omp.wsloop { 170 // ... 171 // store 172 // %cmp = llvm.icmp "eq" %iv %ub 173 // scf.if %cmp { 174 // ^%lpv_update_blk: 175 // } 176 // omp.yield 177 // } 178 179 Operation *lastOper = wsLoopOp->region().back().getTerminator(); 180 181 firOpBuilder.setInsertionPoint(lastOper); 182 183 // TODO: The following will not work when there is collapse present. 184 // Have to modify this in future. 185 for (const Fortran::parser::OmpClause &clause : opClauseList.v) 186 if (const auto &collapseClause = 187 std::get_if<Fortran::parser::OmpClause::Collapse>(&clause.u)) 188 TODO(converter.getCurrentLocation(), 189 "Collapse clause with lastprivate"); 190 // Only generate the compare once in presence of multiple LastPrivate 191 // clauses 192 if (!hasLastPrivateOp) { 193 cmpOp = firOpBuilder.create<mlir::arith::CmpIOp>( 194 wsLoopOp->getLoc(), mlir::arith::CmpIPredicate::eq, 195 wsLoopOp->getRegion().front().getArguments()[0], 196 wsLoopOp->upperBound()[0]); 197 } 198 mlir::scf::IfOp ifOp = firOpBuilder.create<mlir::scf::IfOp>( 199 wsLoopOp->getLoc(), cmpOp, /*else*/ false); 200 201 firOpBuilder.restoreInsertionPoint(insPt); 202 createPrivateVarSyms(converter, lastPrivateClause, 203 &(ifOp.getThenRegion().front())); 204 } else { 205 TODO(converter.getCurrentLocation(), 206 "lastprivate clause in constructs other than work-share loop"); 207 } 208 hasLastPrivateOp = true; 209 } 210 } 211 212 for (const Fortran::parser::OmpClause &clause : opClauseList.v) { 213 if (const auto &defaultClause = 214 std::get_if<Fortran::parser::OmpClause::Default>(&clause.u)) { 215 if (defaultClause->v.v == 216 Fortran::parser::OmpDefaultClause::Type::Private) { 217 converter.collectSymbolSet(eval, symbols, 218 Fortran::semantics::Symbol::Flag::OmpPrivate, 219 /*checkHostAssoicatedSymbols=*/true); 220 removeDuplicateSymbols(symbols); 221 } else if (defaultClause->v.v == 222 Fortran::parser::OmpDefaultClause::Type::Firstprivate) { 223 converter.collectSymbolSet( 224 eval, symbols, Fortran::semantics::Symbol::Flag::OmpFirstPrivate, 225 /*checkHostAssoicatedSymbols=*/true); 226 removeDuplicateSymbols(symbols); 227 } 228 } 229 } 230 231 for (auto sharedSymbol : sharedSymbols) { 232 auto it = uniqueSymbolMap.find(sharedSymbol->name()); 233 if (it != uniqueSymbolMap.end()) { 234 uniqueSymbolMap.erase(it); 235 } 236 } 237 for (auto uniqueSymbolMapPair : uniqueSymbolMap) { 238 // Privatization for symbols which are pre-determined (like loop index 239 // variables) happen separately, for everything else privatize here. 240 const Fortran::semantics::Symbol *sym = uniqueSymbolMapPair.second; 241 if (sym->test(Fortran::semantics::Symbol::Flag::OmpPreDetermined)) 242 continue; 243 bool success = converter.createHostAssociateVarClone(*sym); 244 (void)success; 245 assert(success && "Privatization failed due to existing binding"); 246 if (sym->test(Fortran::semantics::Symbol::Flag::OmpFirstPrivate)) { 247 converter.copyHostAssociateVar(*sym); 248 hasFirstPrivateOp = true; 249 } 250 } 251 252 if (hasFirstPrivateOp || hasLastPrivateOp) 253 firOpBuilder.create<mlir::omp::BarrierOp>(converter.getCurrentLocation()); 254 firOpBuilder.restoreInsertionPoint(insPt); 255 return hasLastPrivateOp; 256 } 257 258 /// The COMMON block is a global structure. \p commonValue is the base address 259 /// of the the COMMON block. As the offset from the symbol \p sym, generate the 260 /// COMMON block member value (commonValue + offset) for the symbol. 261 /// FIXME: Share the code with `instantiateCommon` in ConvertVariable.cpp. 262 static mlir::Value 263 genCommonBlockMember(Fortran::lower::AbstractConverter &converter, 264 const Fortran::semantics::Symbol &sym, 265 mlir::Value commonValue) { 266 auto &firOpBuilder = converter.getFirOpBuilder(); 267 mlir::Location currentLocation = converter.getCurrentLocation(); 268 mlir::IntegerType i8Ty = firOpBuilder.getIntegerType(8); 269 mlir::Type i8Ptr = firOpBuilder.getRefType(i8Ty); 270 mlir::Type seqTy = firOpBuilder.getRefType(firOpBuilder.getVarLenSeqTy(i8Ty)); 271 mlir::Value base = 272 firOpBuilder.createConvert(currentLocation, seqTy, commonValue); 273 std::size_t byteOffset = sym.GetUltimate().offset(); 274 mlir::Value offs = firOpBuilder.createIntegerConstant( 275 currentLocation, firOpBuilder.getIndexType(), byteOffset); 276 mlir::Value varAddr = firOpBuilder.create<fir::CoordinateOp>( 277 currentLocation, i8Ptr, base, mlir::ValueRange{offs}); 278 mlir::Type symType = converter.genType(sym); 279 return firOpBuilder.createConvert(currentLocation, 280 firOpBuilder.getRefType(symType), varAddr); 281 } 282 283 // Get the extended value for \p val by extracting additional variable 284 // information from \p base. 285 static fir::ExtendedValue getExtendedValue(fir::ExtendedValue base, 286 mlir::Value val) { 287 return base.match( 288 [&](const fir::MutableBoxValue &box) -> fir::ExtendedValue { 289 return fir::MutableBoxValue(val, box.nonDeferredLenParams(), {}); 290 }, 291 [&](const auto &) -> fir::ExtendedValue { 292 return fir::substBase(base, val); 293 }); 294 } 295 296 static void threadPrivatizeVars(Fortran::lower::AbstractConverter &converter, 297 Fortran::lower::pft::Evaluation &eval) { 298 auto &firOpBuilder = converter.getFirOpBuilder(); 299 mlir::Location currentLocation = converter.getCurrentLocation(); 300 auto insPt = firOpBuilder.saveInsertionPoint(); 301 firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock()); 302 303 // Get the original ThreadprivateOp corresponding to the symbol and use the 304 // symbol value from that opeartion to create one ThreadprivateOp copy 305 // operation inside the parallel region. 306 auto genThreadprivateOp = [&](Fortran::lower::SymbolRef sym) -> mlir::Value { 307 mlir::Value symOriThreadprivateValue = converter.getSymbolAddress(sym); 308 mlir::Operation *op = symOriThreadprivateValue.getDefiningOp(); 309 assert(mlir::isa<mlir::omp::ThreadprivateOp>(op) && 310 "The threadprivate operation not created"); 311 mlir::Value symValue = 312 mlir::dyn_cast<mlir::omp::ThreadprivateOp>(op).sym_addr(); 313 return firOpBuilder.create<mlir::omp::ThreadprivateOp>( 314 currentLocation, symValue.getType(), symValue); 315 }; 316 317 llvm::SetVector<const Fortran::semantics::Symbol *> threadprivateSyms; 318 converter.collectSymbolSet(eval, threadprivateSyms, 319 Fortran::semantics::Symbol::Flag::OmpThreadprivate, 320 /*isUltimateSymbol=*/false); 321 std::set<Fortran::semantics::SourceName> threadprivateSymNames; 322 323 // For a COMMON block, the ThreadprivateOp is generated for itself instead of 324 // its members, so only bind the value of the new copied ThreadprivateOp 325 // inside the parallel region to the common block symbol only once for 326 // multiple members in one COMMON block. 327 llvm::SetVector<const Fortran::semantics::Symbol *> commonSyms; 328 for (std::size_t i = 0; i < threadprivateSyms.size(); i++) { 329 auto sym = threadprivateSyms[i]; 330 mlir::Value symThreadprivateValue; 331 // The variable may be used more than once, and each reference has one 332 // symbol with the same name. Only do once for references of one variable. 333 if (threadprivateSymNames.find(sym->name()) != threadprivateSymNames.end()) 334 continue; 335 threadprivateSymNames.insert(sym->name()); 336 if (const Fortran::semantics::Symbol *common = 337 Fortran::semantics::FindCommonBlockContaining(sym->GetUltimate())) { 338 mlir::Value commonThreadprivateValue; 339 if (commonSyms.contains(common)) { 340 commonThreadprivateValue = converter.getSymbolAddress(*common); 341 } else { 342 commonThreadprivateValue = genThreadprivateOp(*common); 343 converter.bindSymbol(*common, commonThreadprivateValue); 344 commonSyms.insert(common); 345 } 346 symThreadprivateValue = 347 genCommonBlockMember(converter, *sym, commonThreadprivateValue); 348 } else { 349 symThreadprivateValue = genThreadprivateOp(*sym); 350 } 351 352 fir::ExtendedValue sexv = converter.getSymbolExtendedValue(*sym); 353 fir::ExtendedValue symThreadprivateExv = 354 getExtendedValue(sexv, symThreadprivateValue); 355 converter.bindSymbol(*sym, symThreadprivateExv); 356 } 357 358 firOpBuilder.restoreInsertionPoint(insPt); 359 } 360 361 static void 362 genCopyinClause(Fortran::lower::AbstractConverter &converter, 363 const Fortran::parser::OmpClauseList &opClauseList) { 364 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 365 mlir::OpBuilder::InsertPoint insPt = firOpBuilder.saveInsertionPoint(); 366 firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock()); 367 bool hasCopyin = false; 368 for (const Fortran::parser::OmpClause &clause : opClauseList.v) { 369 if (const auto ©inClause = 370 std::get_if<Fortran::parser::OmpClause::Copyin>(&clause.u)) { 371 hasCopyin = true; 372 const Fortran::parser::OmpObjectList &ompObjectList = copyinClause->v; 373 for (const Fortran::parser::OmpObject &ompObject : ompObjectList.v) { 374 Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); 375 if (sym->has<Fortran::semantics::CommonBlockDetails>()) 376 TODO(converter.getCurrentLocation(), "common block in Copyin clause"); 377 if (Fortran::semantics::IsAllocatableOrPointer(sym->GetUltimate())) 378 TODO(converter.getCurrentLocation(), 379 "pointer or allocatable variables in Copyin clause"); 380 assert(sym->has<Fortran::semantics::HostAssocDetails>() && 381 "No host-association found"); 382 converter.copyHostAssociateVar(*sym); 383 } 384 } 385 } 386 // [OMP 5.0, 2.19.6.1] The copy is done after the team is formed and prior to 387 // the execution of the associated structured block. Emit implicit barrier to 388 // synchronize threads and avoid data races on propagation master's thread 389 // values of threadprivate variables to local instances of that variables of 390 // all other implicit threads. 391 if (hasCopyin) 392 firOpBuilder.create<mlir::omp::BarrierOp>(converter.getCurrentLocation()); 393 firOpBuilder.restoreInsertionPoint(insPt); 394 } 395 396 static void genObjectList(const Fortran::parser::OmpObjectList &objectList, 397 Fortran::lower::AbstractConverter &converter, 398 llvm::SmallVectorImpl<Value> &operands) { 399 auto addOperands = [&](Fortran::lower::SymbolRef sym) { 400 const mlir::Value variable = converter.getSymbolAddress(sym); 401 if (variable) { 402 operands.push_back(variable); 403 } else { 404 if (const auto *details = 405 sym->detailsIf<Fortran::semantics::HostAssocDetails>()) { 406 operands.push_back(converter.getSymbolAddress(details->symbol())); 407 converter.copySymbolBinding(details->symbol(), sym); 408 } 409 } 410 }; 411 for (const Fortran::parser::OmpObject &ompObject : objectList.v) { 412 Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); 413 addOperands(*sym); 414 } 415 } 416 417 static mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter, 418 std::size_t loopVarTypeSize) { 419 // OpenMP runtime requires 32-bit or 64-bit loop variables. 420 loopVarTypeSize = loopVarTypeSize * 8; 421 if (loopVarTypeSize < 32) { 422 loopVarTypeSize = 32; 423 } else if (loopVarTypeSize > 64) { 424 loopVarTypeSize = 64; 425 mlir::emitWarning(converter.getCurrentLocation(), 426 "OpenMP loop iteration variable cannot have more than 64 " 427 "bits size and will be narrowed into 64 bits."); 428 } 429 assert((loopVarTypeSize == 32 || loopVarTypeSize == 64) && 430 "OpenMP loop iteration variable size must be transformed into 32-bit " 431 "or 64-bit"); 432 return converter.getFirOpBuilder().getIntegerType(loopVarTypeSize); 433 } 434 435 /// Create empty blocks for the current region. 436 /// These blocks replace blocks parented to an enclosing region. 437 void createEmptyRegionBlocks( 438 fir::FirOpBuilder &firOpBuilder, 439 std::list<Fortran::lower::pft::Evaluation> &evaluationList) { 440 auto *region = &firOpBuilder.getRegion(); 441 for (auto &eval : evaluationList) { 442 if (eval.block) { 443 if (eval.block->empty()) { 444 eval.block->erase(); 445 eval.block = firOpBuilder.createBlock(region); 446 } else { 447 [[maybe_unused]] auto &terminatorOp = eval.block->back(); 448 assert((mlir::isa<mlir::omp::TerminatorOp>(terminatorOp) || 449 mlir::isa<mlir::omp::YieldOp>(terminatorOp)) && 450 "expected terminator op"); 451 } 452 } 453 if (!eval.isDirective() && eval.hasNestedEvaluations()) 454 createEmptyRegionBlocks(firOpBuilder, eval.getNestedEvaluations()); 455 } 456 } 457 458 void resetBeforeTerminator(fir::FirOpBuilder &firOpBuilder, 459 mlir::Operation *storeOp, mlir::Block &block) { 460 if (storeOp) 461 firOpBuilder.setInsertionPointAfter(storeOp); 462 else 463 firOpBuilder.setInsertionPointToStart(&block); 464 } 465 466 /// Create the body (block) for an OpenMP Operation. 467 /// 468 /// \param [in] op - the operation the body belongs to. 469 /// \param [inout] converter - converter to use for the clauses. 470 /// \param [in] loc - location in source code. 471 /// \param [in] eval - current PFT node/evaluation. 472 /// \oaran [in] clauses - list of clauses to process. 473 /// \param [in] args - block arguments (induction variable[s]) for the 474 //// region. 475 /// \param [in] outerCombined - is this an outer operation - prevents 476 /// privatization. 477 template <typename Op> 478 static void 479 createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter, 480 mlir::Location &loc, Fortran::lower::pft::Evaluation &eval, 481 const Fortran::parser::OmpClauseList *clauses = nullptr, 482 const SmallVector<const Fortran::semantics::Symbol *> &args = {}, 483 bool outerCombined = false) { 484 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 485 // If an argument for the region is provided then create the block with that 486 // argument. Also update the symbol's address with the mlir argument value. 487 // e.g. For loops the argument is the induction variable. And all further 488 // uses of the induction variable should use this mlir value. 489 mlir::Operation *storeOp = nullptr; 490 if (args.size()) { 491 std::size_t loopVarTypeSize = 0; 492 for (const Fortran::semantics::Symbol *arg : args) 493 loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size()); 494 mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); 495 SmallVector<Type> tiv; 496 SmallVector<Location> locs; 497 for (int i = 0; i < (int)args.size(); i++) { 498 tiv.push_back(loopVarType); 499 locs.push_back(loc); 500 } 501 firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs); 502 int argIndex = 0; 503 // The argument is not currently in memory, so make a temporary for the 504 // argument, and store it there, then bind that location to the argument. 505 for (const Fortran::semantics::Symbol *arg : args) { 506 mlir::Value val = 507 fir::getBase(op.getRegion().front().getArgument(argIndex)); 508 mlir::Value temp = firOpBuilder.createTemporary( 509 loc, loopVarType, 510 llvm::ArrayRef<mlir::NamedAttribute>{ 511 Fortran::lower::getAdaptToByRefAttr(firOpBuilder)}); 512 storeOp = firOpBuilder.create<fir::StoreOp>(loc, val, temp); 513 converter.bindSymbol(*arg, temp); 514 argIndex++; 515 } 516 } else { 517 firOpBuilder.createBlock(&op.getRegion()); 518 } 519 // Set the insert for the terminator operation to go at the end of the 520 // block - this is either empty or the block with the stores above, 521 // the end of the block works for both. 522 mlir::Block &block = op.getRegion().back(); 523 firOpBuilder.setInsertionPointToEnd(&block); 524 525 // If it is an unstructured region and is not the outer region of a combined 526 // construct, create empty blocks for all evaluations. 527 if (eval.lowerAsUnstructured() && !outerCombined) 528 createEmptyRegionBlocks(firOpBuilder, eval.getNestedEvaluations()); 529 530 // Insert the terminator. 531 if constexpr (std::is_same_v<Op, omp::WsLoopOp> || 532 std::is_same_v<Op, omp::SimdLoopOp>) { 533 mlir::ValueRange results; 534 firOpBuilder.create<mlir::omp::YieldOp>(loc, results); 535 } else { 536 firOpBuilder.create<mlir::omp::TerminatorOp>(loc); 537 } 538 539 // Reset the insert point to before the terminator. 540 resetBeforeTerminator(firOpBuilder, storeOp, block); 541 542 // Handle privatization. Do not privatize if this is the outer operation. 543 if (clauses && !outerCombined) { 544 bool lastPrivateOp = privatizeVars(op, converter, *clauses, eval); 545 // LastPrivatization, due to introduction of 546 // new control flow, changes the insertion point, 547 // thus restore it. 548 // TODO: Clean up later a bit to avoid this many sets and resets. 549 if (lastPrivateOp) 550 resetBeforeTerminator(firOpBuilder, storeOp, block); 551 } 552 553 if constexpr (std::is_same_v<Op, omp::ParallelOp>) { 554 threadPrivatizeVars(converter, eval); 555 if (clauses) 556 genCopyinClause(converter, *clauses); 557 } 558 } 559 560 static void genOMP(Fortran::lower::AbstractConverter &converter, 561 Fortran::lower::pft::Evaluation &eval, 562 const Fortran::parser::OpenMPSimpleStandaloneConstruct 563 &simpleStandaloneConstruct) { 564 const auto &directive = 565 std::get<Fortran::parser::OmpSimpleStandaloneDirective>( 566 simpleStandaloneConstruct.t); 567 switch (directive.v) { 568 default: 569 break; 570 case llvm::omp::Directive::OMPD_barrier: 571 converter.getFirOpBuilder().create<mlir::omp::BarrierOp>( 572 converter.getCurrentLocation()); 573 break; 574 case llvm::omp::Directive::OMPD_taskwait: 575 converter.getFirOpBuilder().create<mlir::omp::TaskwaitOp>( 576 converter.getCurrentLocation()); 577 break; 578 case llvm::omp::Directive::OMPD_taskyield: 579 converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>( 580 converter.getCurrentLocation()); 581 break; 582 case llvm::omp::Directive::OMPD_target_enter_data: 583 TODO(converter.getCurrentLocation(), "OMPD_target_enter_data"); 584 case llvm::omp::Directive::OMPD_target_exit_data: 585 TODO(converter.getCurrentLocation(), "OMPD_target_exit_data"); 586 case llvm::omp::Directive::OMPD_target_update: 587 TODO(converter.getCurrentLocation(), "OMPD_target_update"); 588 case llvm::omp::Directive::OMPD_ordered: 589 TODO(converter.getCurrentLocation(), "OMPD_ordered"); 590 } 591 } 592 593 static void 594 genAllocateClause(Fortran::lower::AbstractConverter &converter, 595 const Fortran::parser::OmpAllocateClause &ompAllocateClause, 596 SmallVector<Value> &allocatorOperands, 597 SmallVector<Value> &allocateOperands) { 598 auto &firOpBuilder = converter.getFirOpBuilder(); 599 auto currentLocation = converter.getCurrentLocation(); 600 Fortran::lower::StatementContext stmtCtx; 601 602 mlir::Value allocatorOperand; 603 const Fortran::parser::OmpObjectList &ompObjectList = 604 std::get<Fortran::parser::OmpObjectList>(ompAllocateClause.t); 605 const auto &allocatorValue = 606 std::get<std::optional<Fortran::parser::OmpAllocateClause::Allocator>>( 607 ompAllocateClause.t); 608 // Check if allocate clause has allocator specified. If so, add it 609 // to list of allocators, otherwise, add default allocator to 610 // list of allocators. 611 if (allocatorValue) { 612 allocatorOperand = fir::getBase(converter.genExprValue( 613 *Fortran::semantics::GetExpr(allocatorValue->v), stmtCtx)); 614 allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), 615 allocatorOperand); 616 } else { 617 allocatorOperand = firOpBuilder.createIntegerConstant( 618 currentLocation, firOpBuilder.getI32Type(), 1); 619 allocatorOperands.insert(allocatorOperands.end(), ompObjectList.v.size(), 620 allocatorOperand); 621 } 622 genObjectList(ompObjectList, converter, allocateOperands); 623 } 624 625 static void 626 genOMP(Fortran::lower::AbstractConverter &converter, 627 Fortran::lower::pft::Evaluation &eval, 628 const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) { 629 std::visit( 630 Fortran::common::visitors{ 631 [&](const Fortran::parser::OpenMPSimpleStandaloneConstruct 632 &simpleStandaloneConstruct) { 633 genOMP(converter, eval, simpleStandaloneConstruct); 634 }, 635 [&](const Fortran::parser::OpenMPFlushConstruct &flushConstruct) { 636 SmallVector<Value, 4> operandRange; 637 if (const auto &ompObjectList = 638 std::get<std::optional<Fortran::parser::OmpObjectList>>( 639 flushConstruct.t)) 640 genObjectList(*ompObjectList, converter, operandRange); 641 const auto &memOrderClause = std::get<std::optional< 642 std::list<Fortran::parser::OmpMemoryOrderClause>>>( 643 flushConstruct.t); 644 if (memOrderClause.has_value() && memOrderClause->size() > 0) 645 TODO(converter.getCurrentLocation(), 646 "Handle OmpMemoryOrderClause"); 647 converter.getFirOpBuilder().create<mlir::omp::FlushOp>( 648 converter.getCurrentLocation(), operandRange); 649 }, 650 [&](const Fortran::parser::OpenMPCancelConstruct &cancelConstruct) { 651 TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct"); 652 }, 653 [&](const Fortran::parser::OpenMPCancellationPointConstruct 654 &cancellationPointConstruct) { 655 TODO(converter.getCurrentLocation(), "OpenMPCancelConstruct"); 656 }, 657 }, 658 standaloneConstruct.u); 659 } 660 661 static omp::ClauseProcBindKindAttr genProcBindKindAttr( 662 fir::FirOpBuilder &firOpBuilder, 663 const Fortran::parser::OmpClause::ProcBind *procBindClause) { 664 omp::ClauseProcBindKind pbKind; 665 switch (procBindClause->v.v) { 666 case Fortran::parser::OmpProcBindClause::Type::Master: 667 pbKind = omp::ClauseProcBindKind::Master; 668 break; 669 case Fortran::parser::OmpProcBindClause::Type::Close: 670 pbKind = omp::ClauseProcBindKind::Close; 671 break; 672 case Fortran::parser::OmpProcBindClause::Type::Spread: 673 pbKind = omp::ClauseProcBindKind::Spread; 674 break; 675 case Fortran::parser::OmpProcBindClause::Type::Primary: 676 pbKind = omp::ClauseProcBindKind::Primary; 677 break; 678 } 679 return omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(), pbKind); 680 } 681 682 static mlir::Value 683 getIfClauseOperand(Fortran::lower::AbstractConverter &converter, 684 Fortran::lower::StatementContext &stmtCtx, 685 const Fortran::parser::OmpClause::If *ifClause) { 686 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 687 mlir::Location currentLocation = converter.getCurrentLocation(); 688 auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t); 689 mlir::Value ifVal = fir::getBase( 690 converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)); 691 return firOpBuilder.createConvert(currentLocation, firOpBuilder.getI1Type(), 692 ifVal); 693 } 694 695 /* When parallel is used in a combined construct, then use this function to 696 * create the parallel operation. It handles the parallel specific clauses 697 * and leaves the rest for handling at the inner operations. 698 * TODO: Refactor clause handling 699 */ 700 template <typename Directive> 701 static void 702 createCombinedParallelOp(Fortran::lower::AbstractConverter &converter, 703 Fortran::lower::pft::Evaluation &eval, 704 const Directive &directive) { 705 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 706 mlir::Location currentLocation = converter.getCurrentLocation(); 707 Fortran::lower::StatementContext stmtCtx; 708 llvm::ArrayRef<mlir::Type> argTy; 709 mlir::Value ifClauseOperand, numThreadsClauseOperand; 710 SmallVector<Value> allocatorOperands, allocateOperands; 711 mlir::omp::ClauseProcBindKindAttr procBindKindAttr; 712 const auto &opClauseList = 713 std::get<Fortran::parser::OmpClauseList>(directive.t); 714 // TODO: Handle the following clauses 715 // 1. default 716 // Note: rest of the clauses are handled when the inner operation is created 717 for (const Fortran::parser::OmpClause &clause : opClauseList.v) { 718 if (const auto &ifClause = 719 std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) { 720 ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause); 721 } else if (const auto &numThreadsClause = 722 std::get_if<Fortran::parser::OmpClause::NumThreads>( 723 &clause.u)) { 724 numThreadsClauseOperand = fir::getBase(converter.genExprValue( 725 *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx)); 726 } else if (const auto &procBindClause = 727 std::get_if<Fortran::parser::OmpClause::ProcBind>( 728 &clause.u)) { 729 procBindKindAttr = genProcBindKindAttr(firOpBuilder, procBindClause); 730 } 731 } 732 // Create and insert the operation. 733 auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>( 734 currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand, 735 allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(), 736 /*reductions=*/nullptr, procBindKindAttr); 737 738 createBodyOfOp<omp::ParallelOp>(parallelOp, converter, currentLocation, eval, 739 &opClauseList, /*iv=*/{}, 740 /*isCombined=*/true); 741 } 742 743 static void 744 genOMP(Fortran::lower::AbstractConverter &converter, 745 Fortran::lower::pft::Evaluation &eval, 746 const Fortran::parser::OpenMPBlockConstruct &blockConstruct) { 747 const auto &beginBlockDirective = 748 std::get<Fortran::parser::OmpBeginBlockDirective>(blockConstruct.t); 749 const auto &blockDirective = 750 std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t); 751 const auto &endBlockDirective = 752 std::get<Fortran::parser::OmpEndBlockDirective>(blockConstruct.t); 753 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 754 mlir::Location currentLocation = converter.getCurrentLocation(); 755 756 Fortran::lower::StatementContext stmtCtx; 757 llvm::ArrayRef<mlir::Type> argTy; 758 mlir::Value ifClauseOperand, numThreadsClauseOperand, finalClauseOperand, 759 priorityClauseOperand; 760 mlir::omp::ClauseProcBindKindAttr procBindKindAttr; 761 SmallVector<Value> allocateOperands, allocatorOperands; 762 mlir::UnitAttr nowaitAttr, untiedAttr, mergeableAttr; 763 764 const auto &opClauseList = 765 std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t); 766 for (const auto &clause : opClauseList.v) { 767 if (const auto &ifClause = 768 std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) { 769 ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause); 770 } else if (const auto &numThreadsClause = 771 std::get_if<Fortran::parser::OmpClause::NumThreads>( 772 &clause.u)) { 773 // OMPIRBuilder expects `NUM_THREAD` clause as a `Value`. 774 numThreadsClauseOperand = fir::getBase(converter.genExprValue( 775 *Fortran::semantics::GetExpr(numThreadsClause->v), stmtCtx)); 776 } else if (const auto &procBindClause = 777 std::get_if<Fortran::parser::OmpClause::ProcBind>( 778 &clause.u)) { 779 procBindKindAttr = genProcBindKindAttr(firOpBuilder, procBindClause); 780 } else if (const auto &allocateClause = 781 std::get_if<Fortran::parser::OmpClause::Allocate>( 782 &clause.u)) { 783 genAllocateClause(converter, allocateClause->v, allocatorOperands, 784 allocateOperands); 785 } else if (std::get_if<Fortran::parser::OmpClause::Private>(&clause.u) || 786 std::get_if<Fortran::parser::OmpClause::Firstprivate>( 787 &clause.u) || 788 std::get_if<Fortran::parser::OmpClause::Copyin>(&clause.u)) { 789 // Privatisation and copyin clauses are handled elsewhere. 790 continue; 791 } else if (std::get_if<Fortran::parser::OmpClause::Shared>(&clause.u)) { 792 // Shared is the default behavior in the IR, so no handling is required. 793 continue; 794 } else if (const auto &defaultClause = 795 std::get_if<Fortran::parser::OmpClause::Default>( 796 &clause.u)) { 797 if ((defaultClause->v.v == 798 Fortran::parser::OmpDefaultClause::Type::Shared) || 799 (defaultClause->v.v == 800 Fortran::parser::OmpDefaultClause::Type::None)) { 801 // Default clause with shared or none do not require any handling since 802 // Shared is the default behavior in the IR and None is only required 803 // for semantic checks. 804 continue; 805 } 806 } else if (std::get_if<Fortran::parser::OmpClause::Threads>(&clause.u)) { 807 // Nothing needs to be done for threads clause. 808 continue; 809 } else if (const auto &finalClause = 810 std::get_if<Fortran::parser::OmpClause::Final>(&clause.u)) { 811 mlir::Value finalVal = fir::getBase(converter.genExprValue( 812 *Fortran::semantics::GetExpr(finalClause->v), stmtCtx)); 813 finalClauseOperand = firOpBuilder.createConvert( 814 currentLocation, firOpBuilder.getI1Type(), finalVal); 815 } else if (std::get_if<Fortran::parser::OmpClause::Untied>(&clause.u)) { 816 untiedAttr = firOpBuilder.getUnitAttr(); 817 } else if (std::get_if<Fortran::parser::OmpClause::Mergeable>(&clause.u)) { 818 mergeableAttr = firOpBuilder.getUnitAttr(); 819 } else if (const auto &priorityClause = 820 std::get_if<Fortran::parser::OmpClause::Priority>( 821 &clause.u)) { 822 priorityClauseOperand = fir::getBase(converter.genExprValue( 823 *Fortran::semantics::GetExpr(priorityClause->v), stmtCtx)); 824 } else { 825 TODO(currentLocation, "OpenMP Block construct clauses"); 826 } 827 } 828 829 for (const auto &clause : 830 std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t).v) { 831 if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u)) 832 nowaitAttr = firOpBuilder.getUnitAttr(); 833 } 834 835 if (blockDirective.v == llvm::omp::OMPD_parallel) { 836 // Create and insert the operation. 837 auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>( 838 currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand, 839 allocateOperands, allocatorOperands, /*reduction_vars=*/ValueRange(), 840 /*reductions=*/nullptr, procBindKindAttr); 841 createBodyOfOp<omp::ParallelOp>(parallelOp, converter, currentLocation, 842 eval, &opClauseList); 843 } else if (blockDirective.v == llvm::omp::OMPD_master) { 844 auto masterOp = 845 firOpBuilder.create<mlir::omp::MasterOp>(currentLocation, argTy); 846 createBodyOfOp<omp::MasterOp>(masterOp, converter, currentLocation, eval); 847 } else if (blockDirective.v == llvm::omp::OMPD_single) { 848 auto singleOp = firOpBuilder.create<mlir::omp::SingleOp>( 849 currentLocation, allocateOperands, allocatorOperands, nowaitAttr); 850 createBodyOfOp<omp::SingleOp>(singleOp, converter, currentLocation, eval); 851 } else if (blockDirective.v == llvm::omp::OMPD_ordered) { 852 auto orderedOp = firOpBuilder.create<mlir::omp::OrderedRegionOp>( 853 currentLocation, /*simd=*/nullptr); 854 createBodyOfOp<omp::OrderedRegionOp>(orderedOp, converter, currentLocation, 855 eval); 856 } else if (blockDirective.v == llvm::omp::OMPD_task) { 857 auto taskOp = firOpBuilder.create<mlir::omp::TaskOp>( 858 currentLocation, ifClauseOperand, finalClauseOperand, untiedAttr, 859 mergeableAttr, /*in_reduction_vars=*/ValueRange(), 860 /*in_reductions=*/nullptr, priorityClauseOperand, allocateOperands, 861 allocatorOperands); 862 createBodyOfOp(taskOp, converter, currentLocation, eval, &opClauseList); 863 } else { 864 TODO(converter.getCurrentLocation(), "Unhandled block directive"); 865 } 866 } 867 868 /// Creates an OpenMP reduction declaration and inserts it into the provided 869 /// symbol table. The declaration has a constant initializer with the neutral 870 /// value `initValue`, and the reduction combiner carried over from `reduce`. 871 /// TODO: Generalize this for non-integer types, add atomic region. 872 static omp::ReductionDeclareOp createReductionDecl(fir::FirOpBuilder &builder, 873 llvm::StringRef name, 874 mlir::Type type, 875 mlir::Location loc) { 876 OpBuilder::InsertionGuard guard(builder); 877 mlir::ModuleOp module = builder.getModule(); 878 mlir::OpBuilder modBuilder(module.getBodyRegion()); 879 auto decl = module.lookupSymbol<mlir::omp::ReductionDeclareOp>(name); 880 if (!decl) 881 decl = modBuilder.create<omp::ReductionDeclareOp>(loc, name, type); 882 else 883 return decl; 884 885 builder.createBlock(&decl.initializerRegion(), decl.initializerRegion().end(), 886 {type}, {loc}); 887 builder.setInsertionPointToEnd(&decl.initializerRegion().back()); 888 Value init = builder.create<mlir::arith::ConstantOp>( 889 loc, type, builder.getIntegerAttr(type, 0)); 890 builder.create<omp::YieldOp>(loc, init); 891 892 builder.createBlock(&decl.reductionRegion(), decl.reductionRegion().end(), 893 {type, type}, {loc, loc}); 894 builder.setInsertionPointToEnd(&decl.reductionRegion().back()); 895 mlir::Value op1 = decl.reductionRegion().front().getArgument(0); 896 mlir::Value op2 = decl.reductionRegion().front().getArgument(1); 897 Value addRes = builder.create<mlir::arith::AddIOp>(loc, op1, op2); 898 builder.create<omp::YieldOp>(loc, addRes); 899 return decl; 900 } 901 902 static mlir::omp::ScheduleModifier 903 translateModifier(const Fortran::parser::OmpScheduleModifierType &m) { 904 switch (m.v) { 905 case Fortran::parser::OmpScheduleModifierType::ModType::Monotonic: 906 return mlir::omp::ScheduleModifier::monotonic; 907 case Fortran::parser::OmpScheduleModifierType::ModType::Nonmonotonic: 908 return mlir::omp::ScheduleModifier::nonmonotonic; 909 case Fortran::parser::OmpScheduleModifierType::ModType::Simd: 910 return mlir::omp::ScheduleModifier::simd; 911 } 912 return mlir::omp::ScheduleModifier::none; 913 } 914 915 static mlir::omp::ScheduleModifier 916 getScheduleModifier(const Fortran::parser::OmpScheduleClause &x) { 917 const auto &modifier = 918 std::get<std::optional<Fortran::parser::OmpScheduleModifier>>(x.t); 919 // The input may have the modifier any order, so we look for one that isn't 920 // SIMD. If modifier is not set at all, fall down to the bottom and return 921 // "none". 922 if (modifier) { 923 const auto &modType1 = 924 std::get<Fortran::parser::OmpScheduleModifier::Modifier1>(modifier->t); 925 if (modType1.v.v == 926 Fortran::parser::OmpScheduleModifierType::ModType::Simd) { 927 const auto &modType2 = std::get< 928 std::optional<Fortran::parser::OmpScheduleModifier::Modifier2>>( 929 modifier->t); 930 if (modType2 && 931 modType2->v.v != 932 Fortran::parser::OmpScheduleModifierType::ModType::Simd) 933 return translateModifier(modType2->v); 934 935 return mlir::omp::ScheduleModifier::none; 936 } 937 938 return translateModifier(modType1.v); 939 } 940 return mlir::omp::ScheduleModifier::none; 941 } 942 943 static mlir::omp::ScheduleModifier 944 getSIMDModifier(const Fortran::parser::OmpScheduleClause &x) { 945 const auto &modifier = 946 std::get<std::optional<Fortran::parser::OmpScheduleModifier>>(x.t); 947 // Either of the two possible modifiers in the input can be the SIMD modifier, 948 // so look in either one, and return simd if we find one. Not found = return 949 // "none". 950 if (modifier) { 951 const auto &modType1 = 952 std::get<Fortran::parser::OmpScheduleModifier::Modifier1>(modifier->t); 953 if (modType1.v.v == Fortran::parser::OmpScheduleModifierType::ModType::Simd) 954 return mlir::omp::ScheduleModifier::simd; 955 956 const auto &modType2 = std::get< 957 std::optional<Fortran::parser::OmpScheduleModifier::Modifier2>>( 958 modifier->t); 959 if (modType2 && modType2->v.v == 960 Fortran::parser::OmpScheduleModifierType::ModType::Simd) 961 return mlir::omp::ScheduleModifier::simd; 962 } 963 return mlir::omp::ScheduleModifier::none; 964 } 965 966 static std::string getReductionName( 967 Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, 968 mlir::Type ty) { 969 std::string reductionName; 970 if (intrinsicOp == Fortran::parser::DefinedOperator::IntrinsicOperator::Add) 971 reductionName = "add_reduction"; 972 else 973 reductionName = "other_reduction"; 974 975 return (llvm::Twine(reductionName) + 976 (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) + 977 llvm::Twine(ty.getIntOrFloatBitWidth())) 978 .str(); 979 } 980 981 static void genOMP(Fortran::lower::AbstractConverter &converter, 982 Fortran::lower::pft::Evaluation &eval, 983 const Fortran::parser::OpenMPLoopConstruct &loopConstruct) { 984 985 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 986 mlir::Location currentLocation = converter.getCurrentLocation(); 987 llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, linearVars, 988 linearStepVars, reductionVars; 989 mlir::Value scheduleChunkClauseOperand, ifClauseOperand; 990 mlir::Attribute scheduleClauseOperand, noWaitClauseOperand, 991 orderedClauseOperand, orderClauseOperand; 992 SmallVector<Attribute> reductionDeclSymbols; 993 Fortran::lower::StatementContext stmtCtx; 994 const auto &loopOpClauseList = std::get<Fortran::parser::OmpClauseList>( 995 std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t).t); 996 997 const auto ompDirective = 998 std::get<Fortran::parser::OmpLoopDirective>( 999 std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t).t) 1000 .v; 1001 if (llvm::omp::OMPD_parallel_do == ompDirective) { 1002 createCombinedParallelOp<Fortran::parser::OmpBeginLoopDirective>( 1003 converter, eval, 1004 std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t)); 1005 } else if (llvm::omp::OMPD_do != ompDirective && 1006 llvm::omp::OMPD_simd != ompDirective) { 1007 TODO(converter.getCurrentLocation(), "Construct enclosing do loop"); 1008 } 1009 1010 // Collect the loops to collapse. 1011 auto *doConstructEval = &eval.getFirstNestedEvaluation(); 1012 1013 std::int64_t collapseValue = 1014 Fortran::lower::getCollapseValue(loopOpClauseList); 1015 std::size_t loopVarTypeSize = 0; 1016 SmallVector<const Fortran::semantics::Symbol *> iv; 1017 do { 1018 auto *doLoop = &doConstructEval->getFirstNestedEvaluation(); 1019 auto *doStmt = doLoop->getIf<Fortran::parser::NonLabelDoStmt>(); 1020 assert(doStmt && "Expected do loop to be in the nested evaluation"); 1021 const auto &loopControl = 1022 std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t); 1023 const Fortran::parser::LoopControl::Bounds *bounds = 1024 std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u); 1025 assert(bounds && "Expected bounds for worksharing do loop"); 1026 Fortran::lower::StatementContext stmtCtx; 1027 lowerBound.push_back(fir::getBase(converter.genExprValue( 1028 *Fortran::semantics::GetExpr(bounds->lower), stmtCtx))); 1029 upperBound.push_back(fir::getBase(converter.genExprValue( 1030 *Fortran::semantics::GetExpr(bounds->upper), stmtCtx))); 1031 if (bounds->step) { 1032 step.push_back(fir::getBase(converter.genExprValue( 1033 *Fortran::semantics::GetExpr(bounds->step), stmtCtx))); 1034 } else { // If `step` is not present, assume it as `1`. 1035 step.push_back(firOpBuilder.createIntegerConstant( 1036 currentLocation, firOpBuilder.getIntegerType(32), 1)); 1037 } 1038 iv.push_back(bounds->name.thing.symbol); 1039 loopVarTypeSize = std::max(loopVarTypeSize, 1040 bounds->name.thing.symbol->GetUltimate().size()); 1041 1042 collapseValue--; 1043 doConstructEval = 1044 &*std::next(doConstructEval->getNestedEvaluations().begin()); 1045 } while (collapseValue > 0); 1046 1047 for (const auto &clause : loopOpClauseList.v) { 1048 if (const auto &scheduleClause = 1049 std::get_if<Fortran::parser::OmpClause::Schedule>(&clause.u)) { 1050 if (const auto &chunkExpr = 1051 std::get<std::optional<Fortran::parser::ScalarIntExpr>>( 1052 scheduleClause->v.t)) { 1053 if (const auto *expr = Fortran::semantics::GetExpr(*chunkExpr)) { 1054 scheduleChunkClauseOperand = 1055 fir::getBase(converter.genExprValue(*expr, stmtCtx)); 1056 } 1057 } 1058 } else if (const auto &ifClause = 1059 std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) { 1060 ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause); 1061 } else if (const auto &reductionClause = 1062 std::get_if<Fortran::parser::OmpClause::Reduction>( 1063 &clause.u)) { 1064 omp::ReductionDeclareOp decl; 1065 const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>( 1066 reductionClause->v.t)}; 1067 const auto &objectList{ 1068 std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)}; 1069 if (const auto &redDefinedOp = 1070 std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) { 1071 const auto &intrinsicOp{ 1072 std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>( 1073 redDefinedOp->u)}; 1074 if (intrinsicOp != 1075 Fortran::parser::DefinedOperator::IntrinsicOperator::Add) 1076 TODO(currentLocation, 1077 "Reduction of some intrinsic operators is not supported"); 1078 for (const auto &ompObject : objectList.v) { 1079 if (const auto *name{ 1080 Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { 1081 if (const auto *symbol{name->symbol}) { 1082 mlir::Value symVal = converter.getSymbolAddress(*symbol); 1083 mlir::Type redType = 1084 symVal.getType().cast<fir::ReferenceType>().getEleTy(); 1085 reductionVars.push_back(symVal); 1086 if (redType.isIntOrIndex()) { 1087 decl = createReductionDecl( 1088 firOpBuilder, getReductionName(intrinsicOp, redType), 1089 redType, currentLocation); 1090 } else { 1091 TODO(currentLocation, 1092 "Reduction of some types is not supported"); 1093 } 1094 reductionDeclSymbols.push_back(SymbolRefAttr::get( 1095 firOpBuilder.getContext(), decl.sym_name())); 1096 } 1097 } 1098 } 1099 } else { 1100 TODO(currentLocation, 1101 "Reduction of intrinsic procedures is not supported"); 1102 } 1103 } 1104 } 1105 1106 // The types of lower bound, upper bound, and step are converted into the 1107 // type of the loop variable if necessary. 1108 mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); 1109 for (unsigned it = 0; it < (unsigned)lowerBound.size(); it++) { 1110 lowerBound[it] = firOpBuilder.createConvert(currentLocation, loopVarType, 1111 lowerBound[it]); 1112 upperBound[it] = firOpBuilder.createConvert(currentLocation, loopVarType, 1113 upperBound[it]); 1114 step[it] = 1115 firOpBuilder.createConvert(currentLocation, loopVarType, step[it]); 1116 } 1117 1118 // 2.9.3.1 SIMD construct 1119 // TODO: Support all the clauses 1120 if (llvm::omp::OMPD_simd == ompDirective) { 1121 TypeRange resultType; 1122 auto SimdLoopOp = firOpBuilder.create<mlir::omp::SimdLoopOp>( 1123 currentLocation, resultType, lowerBound, upperBound, step, 1124 ifClauseOperand, /*inclusive=*/firOpBuilder.getUnitAttr()); 1125 createBodyOfOp<omp::SimdLoopOp>(SimdLoopOp, converter, currentLocation, 1126 eval, &loopOpClauseList, iv); 1127 return; 1128 } 1129 1130 // FIXME: Add support for following clauses: 1131 // 1. linear 1132 // 2. order 1133 auto wsLoopOp = firOpBuilder.create<mlir::omp::WsLoopOp>( 1134 currentLocation, lowerBound, upperBound, step, linearVars, linearStepVars, 1135 reductionVars, 1136 reductionDeclSymbols.empty() 1137 ? nullptr 1138 : mlir::ArrayAttr::get(firOpBuilder.getContext(), 1139 reductionDeclSymbols), 1140 scheduleClauseOperand.dyn_cast_or_null<omp::ClauseScheduleKindAttr>(), 1141 scheduleChunkClauseOperand, /*schedule_modifiers=*/nullptr, 1142 /*simd_modifier=*/nullptr, 1143 noWaitClauseOperand.dyn_cast_or_null<UnitAttr>(), 1144 orderedClauseOperand.dyn_cast_or_null<IntegerAttr>(), 1145 orderClauseOperand.dyn_cast_or_null<omp::ClauseOrderKindAttr>(), 1146 /*inclusive=*/firOpBuilder.getUnitAttr()); 1147 1148 // Handle attribute based clauses. 1149 for (const Fortran::parser::OmpClause &clause : loopOpClauseList.v) { 1150 if (const auto &orderedClause = 1151 std::get_if<Fortran::parser::OmpClause::Ordered>(&clause.u)) { 1152 if (orderedClause->v.has_value()) { 1153 const auto *expr = Fortran::semantics::GetExpr(orderedClause->v); 1154 const std::optional<std::int64_t> orderedClauseValue = 1155 Fortran::evaluate::ToInt64(*expr); 1156 wsLoopOp.ordered_valAttr( 1157 firOpBuilder.getI64IntegerAttr(*orderedClauseValue)); 1158 } else { 1159 wsLoopOp.ordered_valAttr(firOpBuilder.getI64IntegerAttr(0)); 1160 } 1161 } else if (const auto &scheduleClause = 1162 std::get_if<Fortran::parser::OmpClause::Schedule>( 1163 &clause.u)) { 1164 mlir::MLIRContext *context = firOpBuilder.getContext(); 1165 const auto &scheduleType = scheduleClause->v; 1166 const auto &scheduleKind = 1167 std::get<Fortran::parser::OmpScheduleClause::ScheduleType>( 1168 scheduleType.t); 1169 switch (scheduleKind) { 1170 case Fortran::parser::OmpScheduleClause::ScheduleType::Static: 1171 wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get( 1172 context, omp::ClauseScheduleKind::Static)); 1173 break; 1174 case Fortran::parser::OmpScheduleClause::ScheduleType::Dynamic: 1175 wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get( 1176 context, omp::ClauseScheduleKind::Dynamic)); 1177 break; 1178 case Fortran::parser::OmpScheduleClause::ScheduleType::Guided: 1179 wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get( 1180 context, omp::ClauseScheduleKind::Guided)); 1181 break; 1182 case Fortran::parser::OmpScheduleClause::ScheduleType::Auto: 1183 wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get( 1184 context, omp::ClauseScheduleKind::Auto)); 1185 break; 1186 case Fortran::parser::OmpScheduleClause::ScheduleType::Runtime: 1187 wsLoopOp.schedule_valAttr(omp::ClauseScheduleKindAttr::get( 1188 context, omp::ClauseScheduleKind::Runtime)); 1189 break; 1190 } 1191 mlir::omp::ScheduleModifier scheduleModifier = 1192 getScheduleModifier(scheduleClause->v); 1193 if (scheduleModifier != mlir::omp::ScheduleModifier::none) 1194 wsLoopOp.schedule_modifierAttr( 1195 omp::ScheduleModifierAttr::get(context, scheduleModifier)); 1196 if (getSIMDModifier(scheduleClause->v) != 1197 mlir::omp::ScheduleModifier::none) 1198 wsLoopOp.simd_modifierAttr(firOpBuilder.getUnitAttr()); 1199 } 1200 } 1201 // In FORTRAN `nowait` clause occur at the end of `omp do` directive. 1202 // i.e 1203 // !$omp do 1204 // <...> 1205 // !$omp end do nowait 1206 if (const auto &endClauseList = 1207 std::get<std::optional<Fortran::parser::OmpEndLoopDirective>>( 1208 loopConstruct.t)) { 1209 const auto &clauseList = 1210 std::get<Fortran::parser::OmpClauseList>((*endClauseList).t); 1211 for (const Fortran::parser::OmpClause &clause : clauseList.v) 1212 if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u)) 1213 wsLoopOp.nowaitAttr(firOpBuilder.getUnitAttr()); 1214 } 1215 1216 createBodyOfOp<omp::WsLoopOp>(wsLoopOp, converter, currentLocation, eval, 1217 &loopOpClauseList, iv); 1218 } 1219 1220 static void 1221 genOMP(Fortran::lower::AbstractConverter &converter, 1222 Fortran::lower::pft::Evaluation &eval, 1223 const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) { 1224 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 1225 mlir::Location currentLocation = converter.getCurrentLocation(); 1226 std::string name; 1227 const Fortran::parser::OmpCriticalDirective &cd = 1228 std::get<Fortran::parser::OmpCriticalDirective>(criticalConstruct.t); 1229 if (std::get<std::optional<Fortran::parser::Name>>(cd.t).has_value()) { 1230 name = 1231 std::get<std::optional<Fortran::parser::Name>>(cd.t).value().ToString(); 1232 } 1233 1234 uint64_t hint = 0; 1235 const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t); 1236 for (const Fortran::parser::OmpClause &clause : clauseList.v) 1237 if (auto hintClause = 1238 std::get_if<Fortran::parser::OmpClause::Hint>(&clause.u)) { 1239 const auto *expr = Fortran::semantics::GetExpr(hintClause->v); 1240 hint = *Fortran::evaluate::ToInt64(*expr); 1241 break; 1242 } 1243 1244 mlir::omp::CriticalOp criticalOp = [&]() { 1245 if (name.empty()) { 1246 return firOpBuilder.create<mlir::omp::CriticalOp>(currentLocation, 1247 FlatSymbolRefAttr()); 1248 } else { 1249 mlir::ModuleOp module = firOpBuilder.getModule(); 1250 mlir::OpBuilder modBuilder(module.getBodyRegion()); 1251 auto global = module.lookupSymbol<mlir::omp::CriticalDeclareOp>(name); 1252 if (!global) 1253 global = modBuilder.create<mlir::omp::CriticalDeclareOp>( 1254 currentLocation, name, hint); 1255 return firOpBuilder.create<mlir::omp::CriticalOp>( 1256 currentLocation, mlir::FlatSymbolRefAttr::get( 1257 firOpBuilder.getContext(), global.sym_name())); 1258 } 1259 }(); 1260 createBodyOfOp<omp::CriticalOp>(criticalOp, converter, currentLocation, eval); 1261 } 1262 1263 static void 1264 genOMP(Fortran::lower::AbstractConverter &converter, 1265 Fortran::lower::pft::Evaluation &eval, 1266 const Fortran::parser::OpenMPSectionConstruct §ionConstruct) { 1267 1268 auto &firOpBuilder = converter.getFirOpBuilder(); 1269 auto currentLocation = converter.getCurrentLocation(); 1270 mlir::omp::SectionOp sectionOp = 1271 firOpBuilder.create<mlir::omp::SectionOp>(currentLocation); 1272 createBodyOfOp<omp::SectionOp>(sectionOp, converter, currentLocation, eval); 1273 } 1274 1275 // TODO: Add support for reduction 1276 static void 1277 genOMP(Fortran::lower::AbstractConverter &converter, 1278 Fortran::lower::pft::Evaluation &eval, 1279 const Fortran::parser::OpenMPSectionsConstruct §ionsConstruct) { 1280 auto &firOpBuilder = converter.getFirOpBuilder(); 1281 auto currentLocation = converter.getCurrentLocation(); 1282 SmallVector<Value> reductionVars, allocateOperands, allocatorOperands; 1283 mlir::UnitAttr noWaitClauseOperand; 1284 const auto §ionsClauseList = std::get<Fortran::parser::OmpClauseList>( 1285 std::get<Fortran::parser::OmpBeginSectionsDirective>(sectionsConstruct.t) 1286 .t); 1287 for (const Fortran::parser::OmpClause &clause : sectionsClauseList.v) { 1288 1289 // Reduction Clause 1290 if (std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) { 1291 TODO(currentLocation, "OMPC_Reduction"); 1292 1293 // Allocate clause 1294 } else if (const auto &allocateClause = 1295 std::get_if<Fortran::parser::OmpClause::Allocate>( 1296 &clause.u)) { 1297 genAllocateClause(converter, allocateClause->v, allocatorOperands, 1298 allocateOperands); 1299 } 1300 } 1301 const auto &endSectionsClauseList = 1302 std::get<Fortran::parser::OmpEndSectionsDirective>(sectionsConstruct.t); 1303 const auto &clauseList = 1304 std::get<Fortran::parser::OmpClauseList>(endSectionsClauseList.t); 1305 for (const auto &clause : clauseList.v) { 1306 // Nowait clause 1307 if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u)) { 1308 noWaitClauseOperand = firOpBuilder.getUnitAttr(); 1309 } 1310 } 1311 1312 llvm::omp::Directive dir = 1313 std::get<Fortran::parser::OmpSectionsDirective>( 1314 std::get<Fortran::parser::OmpBeginSectionsDirective>( 1315 sectionsConstruct.t) 1316 .t) 1317 .v; 1318 1319 // Parallel Sections Construct 1320 if (dir == llvm::omp::Directive::OMPD_parallel_sections) { 1321 createCombinedParallelOp<Fortran::parser::OmpBeginSectionsDirective>( 1322 converter, eval, 1323 std::get<Fortran::parser::OmpBeginSectionsDirective>( 1324 sectionsConstruct.t)); 1325 auto sectionsOp = firOpBuilder.create<mlir::omp::SectionsOp>( 1326 currentLocation, /*reduction_vars*/ ValueRange(), 1327 /*reductions=*/nullptr, allocateOperands, allocatorOperands, 1328 /*nowait=*/nullptr); 1329 createBodyOfOp(sectionsOp, converter, currentLocation, eval); 1330 1331 // Sections Construct 1332 } else if (dir == llvm::omp::Directive::OMPD_sections) { 1333 auto sectionsOp = firOpBuilder.create<mlir::omp::SectionsOp>( 1334 currentLocation, reductionVars, /*reductions = */ nullptr, 1335 allocateOperands, allocatorOperands, noWaitClauseOperand); 1336 createBodyOfOp<omp::SectionsOp>(sectionsOp, converter, currentLocation, 1337 eval); 1338 } 1339 } 1340 1341 static void genOmpAtomicHintAndMemoryOrderClauses( 1342 Fortran::lower::AbstractConverter &converter, 1343 const Fortran::parser::OmpAtomicClauseList &clauseList, 1344 mlir::IntegerAttr &hint, 1345 mlir::omp::ClauseMemoryOrderKindAttr &memory_order) { 1346 auto &firOpBuilder = converter.getFirOpBuilder(); 1347 for (const auto &clause : clauseList.v) { 1348 if (auto ompClause = std::get_if<Fortran::parser::OmpClause>(&clause.u)) { 1349 if (auto hintClause = 1350 std::get_if<Fortran::parser::OmpClause::Hint>(&ompClause->u)) { 1351 const auto *expr = Fortran::semantics::GetExpr(hintClause->v); 1352 uint64_t hintExprValue = *Fortran::evaluate::ToInt64(*expr); 1353 hint = firOpBuilder.getI64IntegerAttr(hintExprValue); 1354 } 1355 } else if (auto ompMemoryOrderClause = 1356 std::get_if<Fortran::parser::OmpMemoryOrderClause>( 1357 &clause.u)) { 1358 if (std::get_if<Fortran::parser::OmpClause::Acquire>( 1359 &ompMemoryOrderClause->v.u)) { 1360 memory_order = mlir::omp::ClauseMemoryOrderKindAttr::get( 1361 firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Acquire); 1362 } else if (std::get_if<Fortran::parser::OmpClause::Relaxed>( 1363 &ompMemoryOrderClause->v.u)) { 1364 memory_order = mlir::omp::ClauseMemoryOrderKindAttr::get( 1365 firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Relaxed); 1366 } else if (std::get_if<Fortran::parser::OmpClause::SeqCst>( 1367 &ompMemoryOrderClause->v.u)) { 1368 memory_order = mlir::omp::ClauseMemoryOrderKindAttr::get( 1369 firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Seq_cst); 1370 } else if (std::get_if<Fortran::parser::OmpClause::Release>( 1371 &ompMemoryOrderClause->v.u)) { 1372 memory_order = mlir::omp::ClauseMemoryOrderKindAttr::get( 1373 firOpBuilder.getContext(), omp::ClauseMemoryOrderKind::Release); 1374 } 1375 } 1376 } 1377 } 1378 1379 static void genOmpAtomicUpdateStatement( 1380 Fortran::lower::AbstractConverter &converter, 1381 Fortran::lower::pft::Evaluation &eval, 1382 const Fortran::parser::Variable &assignmentStmtVariable, 1383 const Fortran::parser::Expr &assignmentStmtExpr, 1384 const Fortran::parser::OmpAtomicClauseList *leftHandClauseList, 1385 const Fortran::parser::OmpAtomicClauseList *rightHandClauseList) { 1386 // Generate `omp.atomic.update` operation for atomic assignment statements 1387 auto &firOpBuilder = converter.getFirOpBuilder(); 1388 auto currentLocation = converter.getCurrentLocation(); 1389 Fortran::lower::StatementContext stmtCtx; 1390 1391 mlir::Value address = fir::getBase(converter.genExprAddr( 1392 *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx)); 1393 // If no hint clause is specified, the effect is as if 1394 // hint(omp_sync_hint_none) had been specified. 1395 mlir::IntegerAttr hint = nullptr; 1396 mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr; 1397 if (leftHandClauseList) 1398 genOmpAtomicHintAndMemoryOrderClauses(converter, *leftHandClauseList, hint, 1399 memory_order); 1400 if (rightHandClauseList) 1401 genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList, hint, 1402 memory_order); 1403 auto atomicUpdateOp = firOpBuilder.create<mlir::omp::AtomicUpdateOp>( 1404 currentLocation, address, hint, memory_order); 1405 1406 //// Generate body of Atomic Update operation 1407 // If an argument for the region is provided then create the block with that 1408 // argument. Also update the symbol's address with the argument mlir value. 1409 mlir::Type varType = 1410 fir::getBase( 1411 converter.genExprValue( 1412 *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx)) 1413 .getType(); 1414 SmallVector<Type> varTys = {varType}; 1415 SmallVector<Location> locs = {currentLocation}; 1416 firOpBuilder.createBlock(&atomicUpdateOp.getRegion(), {}, varTys, locs); 1417 mlir::Value val = 1418 fir::getBase(atomicUpdateOp.getRegion().front().getArgument(0)); 1419 auto varDesignator = 1420 std::get_if<Fortran::common::Indirection<Fortran::parser::Designator>>( 1421 &assignmentStmtVariable.u); 1422 assert(varDesignator && "Variable designator for atomic update assignment " 1423 "statement does not exist"); 1424 const auto *name = getDesignatorNameIfDataRef(varDesignator->value()); 1425 assert(name && name->symbol && 1426 "No symbol attached to atomic update variable"); 1427 converter.bindSymbol(*name->symbol, val); 1428 // Set the insert for the terminator operation to go at the end of the 1429 // block. 1430 mlir::Block &block = atomicUpdateOp.getRegion().back(); 1431 firOpBuilder.setInsertionPointToEnd(&block); 1432 1433 mlir::Value result = fir::getBase(converter.genExprValue( 1434 *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx)); 1435 // Insert the terminator: YieldOp. 1436 firOpBuilder.create<mlir::omp::YieldOp>(currentLocation, result); 1437 // Reset the insert point to before the terminator. 1438 firOpBuilder.setInsertionPointToStart(&block); 1439 } 1440 1441 static void 1442 genOmpAtomicWrite(Fortran::lower::AbstractConverter &converter, 1443 Fortran::lower::pft::Evaluation &eval, 1444 const Fortran::parser::OmpAtomicWrite &atomicWrite) { 1445 auto &firOpBuilder = converter.getFirOpBuilder(); 1446 auto currentLocation = converter.getCurrentLocation(); 1447 // Get the value and address of atomic write operands. 1448 const Fortran::parser::OmpAtomicClauseList &rightHandClauseList = 1449 std::get<2>(atomicWrite.t); 1450 const Fortran::parser::OmpAtomicClauseList &leftHandClauseList = 1451 std::get<0>(atomicWrite.t); 1452 const auto &assignmentStmtExpr = 1453 std::get<Fortran::parser::Expr>(std::get<3>(atomicWrite.t).statement.t); 1454 const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>( 1455 std::get<3>(atomicWrite.t).statement.t); 1456 Fortran::lower::StatementContext stmtCtx; 1457 mlir::Value value = fir::getBase(converter.genExprValue( 1458 *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx)); 1459 mlir::Value address = fir::getBase(converter.genExprAddr( 1460 *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx)); 1461 // If no hint clause is specified, the effect is as if 1462 // hint(omp_sync_hint_none) had been specified. 1463 mlir::IntegerAttr hint = nullptr; 1464 mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr; 1465 genOmpAtomicHintAndMemoryOrderClauses(converter, leftHandClauseList, hint, 1466 memory_order); 1467 genOmpAtomicHintAndMemoryOrderClauses(converter, rightHandClauseList, hint, 1468 memory_order); 1469 firOpBuilder.create<mlir::omp::AtomicWriteOp>(currentLocation, address, value, 1470 hint, memory_order); 1471 } 1472 1473 static void genOmpAtomicRead(Fortran::lower::AbstractConverter &converter, 1474 Fortran::lower::pft::Evaluation &eval, 1475 const Fortran::parser::OmpAtomicRead &atomicRead) { 1476 auto &firOpBuilder = converter.getFirOpBuilder(); 1477 auto currentLocation = converter.getCurrentLocation(); 1478 // Get the address of atomic read operands. 1479 const Fortran::parser::OmpAtomicClauseList &rightHandClauseList = 1480 std::get<2>(atomicRead.t); 1481 const Fortran::parser::OmpAtomicClauseList &leftHandClauseList = 1482 std::get<0>(atomicRead.t); 1483 const auto &assignmentStmtExpr = 1484 std::get<Fortran::parser::Expr>(std::get<3>(atomicRead.t).statement.t); 1485 const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>( 1486 std::get<3>(atomicRead.t).statement.t); 1487 Fortran::lower::StatementContext stmtCtx; 1488 mlir::Value from_address = fir::getBase(converter.genExprAddr( 1489 *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx)); 1490 mlir::Value to_address = fir::getBase(converter.genExprAddr( 1491 *Fortran::semantics::GetExpr(assignmentStmtVariable), stmtCtx)); 1492 // If no hint clause is specified, the effect is as if 1493 // hint(omp_sync_hint_none) had been specified. 1494 mlir::IntegerAttr hint = nullptr; 1495 mlir::omp::ClauseMemoryOrderKindAttr memory_order = nullptr; 1496 genOmpAtomicHintAndMemoryOrderClauses(converter, leftHandClauseList, hint, 1497 memory_order); 1498 genOmpAtomicHintAndMemoryOrderClauses(converter, rightHandClauseList, hint, 1499 memory_order); 1500 firOpBuilder.create<mlir::omp::AtomicReadOp>(currentLocation, from_address, 1501 to_address, hint, memory_order); 1502 } 1503 1504 static void 1505 genOmpAtomicUpdate(Fortran::lower::AbstractConverter &converter, 1506 Fortran::lower::pft::Evaluation &eval, 1507 const Fortran::parser::OmpAtomicUpdate &atomicUpdate) { 1508 const Fortran::parser::OmpAtomicClauseList &rightHandClauseList = 1509 std::get<2>(atomicUpdate.t); 1510 const Fortran::parser::OmpAtomicClauseList &leftHandClauseList = 1511 std::get<0>(atomicUpdate.t); 1512 const auto &assignmentStmtExpr = 1513 std::get<Fortran::parser::Expr>(std::get<3>(atomicUpdate.t).statement.t); 1514 const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>( 1515 std::get<3>(atomicUpdate.t).statement.t); 1516 1517 genOmpAtomicUpdateStatement(converter, eval, assignmentStmtVariable, 1518 assignmentStmtExpr, &leftHandClauseList, 1519 &rightHandClauseList); 1520 } 1521 1522 static void genOmpAtomic(Fortran::lower::AbstractConverter &converter, 1523 Fortran::lower::pft::Evaluation &eval, 1524 const Fortran::parser::OmpAtomic &atomicConstruct) { 1525 const Fortran::parser::OmpAtomicClauseList &atomicClauseList = 1526 std::get<Fortran::parser::OmpAtomicClauseList>(atomicConstruct.t); 1527 const auto &assignmentStmtExpr = std::get<Fortran::parser::Expr>( 1528 std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>( 1529 atomicConstruct.t) 1530 .statement.t); 1531 const auto &assignmentStmtVariable = std::get<Fortran::parser::Variable>( 1532 std::get<Fortran::parser::Statement<Fortran::parser::AssignmentStmt>>( 1533 atomicConstruct.t) 1534 .statement.t); 1535 // If atomic-clause is not present on the construct, the behaviour is as if 1536 // the update clause is specified 1537 genOmpAtomicUpdateStatement(converter, eval, assignmentStmtVariable, 1538 assignmentStmtExpr, &atomicClauseList, nullptr); 1539 } 1540 1541 static void 1542 genOMP(Fortran::lower::AbstractConverter &converter, 1543 Fortran::lower::pft::Evaluation &eval, 1544 const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) { 1545 std::visit(Fortran::common::visitors{ 1546 [&](const Fortran::parser::OmpAtomicRead &atomicRead) { 1547 genOmpAtomicRead(converter, eval, atomicRead); 1548 }, 1549 [&](const Fortran::parser::OmpAtomicWrite &atomicWrite) { 1550 genOmpAtomicWrite(converter, eval, atomicWrite); 1551 }, 1552 [&](const Fortran::parser::OmpAtomic &atomicConstruct) { 1553 genOmpAtomic(converter, eval, atomicConstruct); 1554 }, 1555 [&](const Fortran::parser::OmpAtomicUpdate &atomicUpdate) { 1556 genOmpAtomicUpdate(converter, eval, atomicUpdate); 1557 }, 1558 [&](const auto &) { 1559 TODO(converter.getCurrentLocation(), "Atomic capture"); 1560 }, 1561 }, 1562 atomicConstruct.u); 1563 } 1564 1565 void Fortran::lower::genOpenMPConstruct( 1566 Fortran::lower::AbstractConverter &converter, 1567 Fortran::lower::pft::Evaluation &eval, 1568 const Fortran::parser::OpenMPConstruct &ompConstruct) { 1569 1570 std::visit( 1571 common::visitors{ 1572 [&](const Fortran::parser::OpenMPStandaloneConstruct 1573 &standaloneConstruct) { 1574 genOMP(converter, eval, standaloneConstruct); 1575 }, 1576 [&](const Fortran::parser::OpenMPSectionsConstruct 1577 §ionsConstruct) { 1578 genOMP(converter, eval, sectionsConstruct); 1579 }, 1580 [&](const Fortran::parser::OpenMPSectionConstruct §ionConstruct) { 1581 genOMP(converter, eval, sectionConstruct); 1582 }, 1583 [&](const Fortran::parser::OpenMPLoopConstruct &loopConstruct) { 1584 genOMP(converter, eval, loopConstruct); 1585 }, 1586 [&](const Fortran::parser::OpenMPDeclarativeAllocate 1587 &execAllocConstruct) { 1588 TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate"); 1589 }, 1590 [&](const Fortran::parser::OpenMPExecutableAllocate 1591 &execAllocConstruct) { 1592 TODO(converter.getCurrentLocation(), "OpenMPExecutableAllocate"); 1593 }, 1594 [&](const Fortran::parser::OpenMPBlockConstruct &blockConstruct) { 1595 genOMP(converter, eval, blockConstruct); 1596 }, 1597 [&](const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) { 1598 genOMP(converter, eval, atomicConstruct); 1599 }, 1600 [&](const Fortran::parser::OpenMPCriticalConstruct 1601 &criticalConstruct) { 1602 genOMP(converter, eval, criticalConstruct); 1603 }, 1604 }, 1605 ompConstruct.u); 1606 } 1607 1608 void Fortran::lower::genThreadprivateOp( 1609 Fortran::lower::AbstractConverter &converter, 1610 const Fortran::lower::pft::Variable &var) { 1611 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 1612 mlir::Location currentLocation = converter.getCurrentLocation(); 1613 1614 const Fortran::semantics::Symbol &sym = var.getSymbol(); 1615 mlir::Value symThreadprivateValue; 1616 if (const Fortran::semantics::Symbol *common = 1617 Fortran::semantics::FindCommonBlockContaining(sym.GetUltimate())) { 1618 mlir::Value commonValue = converter.getSymbolAddress(*common); 1619 if (mlir::isa<mlir::omp::ThreadprivateOp>(commonValue.getDefiningOp())) { 1620 // Generate ThreadprivateOp for a common block instead of its members and 1621 // only do it once for a common block. 1622 return; 1623 } 1624 // Generate ThreadprivateOp and rebind the common block. 1625 mlir::Value commonThreadprivateValue = 1626 firOpBuilder.create<mlir::omp::ThreadprivateOp>( 1627 currentLocation, commonValue.getType(), commonValue); 1628 converter.bindSymbol(*common, commonThreadprivateValue); 1629 // Generate the threadprivate value for the common block member. 1630 symThreadprivateValue = 1631 genCommonBlockMember(converter, sym, commonThreadprivateValue); 1632 } else { 1633 mlir::Value symValue = converter.getSymbolAddress(sym); 1634 symThreadprivateValue = firOpBuilder.create<mlir::omp::ThreadprivateOp>( 1635 currentLocation, symValue.getType(), symValue); 1636 } 1637 1638 fir::ExtendedValue sexv = converter.getSymbolExtendedValue(sym); 1639 fir::ExtendedValue symThreadprivateExv = 1640 getExtendedValue(sexv, symThreadprivateValue); 1641 converter.bindSymbol(sym, symThreadprivateExv); 1642 } 1643 1644 void Fortran::lower::genOpenMPDeclarativeConstruct( 1645 Fortran::lower::AbstractConverter &converter, 1646 Fortran::lower::pft::Evaluation &eval, 1647 const Fortran::parser::OpenMPDeclarativeConstruct &ompDeclConstruct) { 1648 1649 std::visit( 1650 common::visitors{ 1651 [&](const Fortran::parser::OpenMPDeclarativeAllocate 1652 &declarativeAllocate) { 1653 TODO(converter.getCurrentLocation(), "OpenMPDeclarativeAllocate"); 1654 }, 1655 [&](const Fortran::parser::OpenMPDeclareReductionConstruct 1656 &declareReductionConstruct) { 1657 TODO(converter.getCurrentLocation(), 1658 "OpenMPDeclareReductionConstruct"); 1659 }, 1660 [&](const Fortran::parser::OpenMPDeclareSimdConstruct 1661 &declareSimdConstruct) { 1662 TODO(converter.getCurrentLocation(), "OpenMPDeclareSimdConstruct"); 1663 }, 1664 [&](const Fortran::parser::OpenMPDeclareTargetConstruct 1665 &declareTargetConstruct) { 1666 TODO(converter.getCurrentLocation(), 1667 "OpenMPDeclareTargetConstruct"); 1668 }, 1669 [&](const Fortran::parser::OpenMPThreadprivate &threadprivate) { 1670 // The directive is lowered when instantiating the variable to 1671 // support the case of threadprivate variable declared in module. 1672 }, 1673 }, 1674 ompDeclConstruct.u); 1675 } 1676 1677 // Generate an OpenMP reduction operation. This implementation finds the chain : 1678 // load reduction var -> reduction_operation -> store reduction var and replaces 1679 // it with the reduction operation. 1680 // TODO: Currently assumes it is an integer addition reduction. Generalize this 1681 // for various reduction operation types. 1682 // TODO: Generate the reduction operation during lowering instead of creating 1683 // and removing operations since this is not a robust approach. Also, removing 1684 // ops in the builder (instead of a rewriter) is probably not the best approach. 1685 void Fortran::lower::genOpenMPReduction( 1686 Fortran::lower::AbstractConverter &converter, 1687 const Fortran::parser::OmpClauseList &clauseList) { 1688 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 1689 1690 for (const auto &clause : clauseList.v) { 1691 if (const auto &reductionClause = 1692 std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) { 1693 const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>( 1694 reductionClause->v.t)}; 1695 const auto &objectList{ 1696 std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)}; 1697 if (auto reductionOp = 1698 std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) { 1699 const auto &intrinsicOp{ 1700 std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>( 1701 reductionOp->u)}; 1702 if (intrinsicOp != 1703 Fortran::parser::DefinedOperator::IntrinsicOperator::Add) 1704 continue; 1705 for (const auto &ompObject : objectList.v) { 1706 if (const auto *name{ 1707 Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { 1708 if (const auto *symbol{name->symbol}) { 1709 mlir::Value symVal = converter.getSymbolAddress(*symbol); 1710 mlir::Type redType = 1711 symVal.getType().cast<fir::ReferenceType>().getEleTy(); 1712 if (!redType.isIntOrIndex()) 1713 continue; 1714 for (mlir::OpOperand &use1 : symVal.getUses()) { 1715 if (auto load = mlir::dyn_cast<fir::LoadOp>(use1.getOwner())) { 1716 mlir::Value loadVal = load.getRes(); 1717 for (mlir::OpOperand &use2 : loadVal.getUses()) { 1718 if (auto add = mlir::dyn_cast<mlir::arith::AddIOp>( 1719 use2.getOwner())) { 1720 mlir::Value addRes = add.getResult(); 1721 for (mlir::OpOperand &use3 : addRes.getUses()) { 1722 if (auto store = 1723 mlir::dyn_cast<fir::StoreOp>(use3.getOwner())) { 1724 if (store.getMemref() == symVal) { 1725 // Chain found! Now replace load->reduction->store 1726 // with the OpenMP reduction operation. 1727 mlir::OpBuilder::InsertPoint insertPtDel = 1728 firOpBuilder.saveInsertionPoint(); 1729 firOpBuilder.setInsertionPoint(add); 1730 if (add.getLhs() == loadVal) { 1731 firOpBuilder.create<mlir::omp::ReductionOp>( 1732 add.getLoc(), add.getRhs(), symVal); 1733 } else { 1734 firOpBuilder.create<mlir::omp::ReductionOp>( 1735 add.getLoc(), add.getLhs(), symVal); 1736 } 1737 store.erase(); 1738 add.erase(); 1739 load.erase(); 1740 firOpBuilder.restoreInsertionPoint(insertPtDel); 1741 } 1742 } 1743 } 1744 } 1745 } 1746 } 1747 } 1748 } 1749 } 1750 } 1751 } 1752 } 1753 } 1754 } 1755