1 //===-- OpenACC.cpp -- OpenACC directive lowering -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Lower/OpenACC.h" 14 #include "flang/Common/idioms.h" 15 #include "flang/Lower/Bridge.h" 16 #include "flang/Lower/PFTBuilder.h" 17 #include "flang/Lower/StatementContext.h" 18 #include "flang/Lower/Todo.h" 19 #include "flang/Optimizer/Builder/BoxValue.h" 20 #include "flang/Optimizer/Builder/FIRBuilder.h" 21 #include "flang/Parser/parse-tree.h" 22 #include "flang/Semantics/tools.h" 23 #include "mlir/Dialect/OpenACC/OpenACC.h" 24 #include "llvm/Frontend/OpenACC/ACC.h.inc" 25 26 using namespace mlir; 27 28 // Special value for * passed in device_type or gang clauses. 29 static constexpr std::int64_t starCst{-1}; 30 31 static const Fortran::parser::Name * 32 getDesignatorNameIfDataRef(const Fortran::parser::Designator &designator) { 33 const auto *dataRef{std::get_if<Fortran::parser::DataRef>(&designator.u)}; 34 return dataRef ? std::get_if<Fortran::parser::Name>(&dataRef->u) : nullptr; 35 } 36 37 static void genObjectList(const Fortran::parser::AccObjectList &objectList, 38 Fortran::lower::AbstractConverter &converter, 39 SmallVectorImpl<Value> &operands) { 40 for (const auto &accObject : objectList.v) { 41 std::visit( 42 Fortran::common::visitors{ 43 [&](const Fortran::parser::Designator &designator) { 44 if (const auto *name = getDesignatorNameIfDataRef(designator)) { 45 const auto variable = converter.getSymbolAddress(*name->symbol); 46 operands.push_back(variable); 47 } 48 }, 49 [&](const Fortran::parser::Name &name) { 50 const auto variable = converter.getSymbolAddress(*name.symbol); 51 operands.push_back(variable); 52 }}, 53 accObject.u); 54 } 55 } 56 57 template <typename Clause> 58 static void 59 genObjectListWithModifier(const Clause *x, 60 Fortran::lower::AbstractConverter &converter, 61 Fortran::parser::AccDataModifier::Modifier mod, 62 SmallVectorImpl<Value> &operandsWithModifier, 63 SmallVectorImpl<Value> &operands) { 64 const Fortran::parser::AccObjectListWithModifier &listWithModifier = x->v; 65 const Fortran::parser::AccObjectList &accObjectList = 66 std::get<Fortran::parser::AccObjectList>(listWithModifier.t); 67 const auto &modifier = 68 std::get<std::optional<Fortran::parser::AccDataModifier>>( 69 listWithModifier.t); 70 if (modifier && (*modifier).v == mod) { 71 genObjectList(accObjectList, converter, operandsWithModifier); 72 } else { 73 genObjectList(accObjectList, converter, operands); 74 } 75 } 76 77 static void addOperands(SmallVectorImpl<Value> &operands, 78 SmallVectorImpl<int32_t> &operandSegments, 79 const SmallVectorImpl<Value> &clauseOperands) { 80 operands.append(clauseOperands.begin(), clauseOperands.end()); 81 operandSegments.push_back(clauseOperands.size()); 82 } 83 84 static void addOperand(SmallVectorImpl<Value> &operands, 85 SmallVectorImpl<int32_t> &operandSegments, 86 const Value &clauseOperand) { 87 if (clauseOperand) { 88 operands.push_back(clauseOperand); 89 operandSegments.push_back(1); 90 } else { 91 operandSegments.push_back(0); 92 } 93 } 94 95 template <typename Op, typename Terminator> 96 static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc, 97 const SmallVectorImpl<Value> &operands, 98 const SmallVectorImpl<int32_t> &operandSegments) { 99 llvm::ArrayRef<mlir::Type> argTy; 100 Op op = builder.create<Op>(loc, argTy, operands); 101 builder.createBlock(&op.getRegion()); 102 auto &block = op.getRegion().back(); 103 builder.setInsertionPointToStart(&block); 104 builder.create<Terminator>(loc); 105 106 op->setAttr(Op::getOperandSegmentSizeAttr(), 107 builder.getI32VectorAttr(operandSegments)); 108 109 // Place the insertion point to the start of the first block. 110 builder.setInsertionPointToStart(&block); 111 112 return op; 113 } 114 115 template <typename Op> 116 static Op createSimpleOp(fir::FirOpBuilder &builder, mlir::Location loc, 117 const SmallVectorImpl<Value> &operands, 118 const SmallVectorImpl<int32_t> &operandSegments) { 119 llvm::ArrayRef<mlir::Type> argTy; 120 Op op = builder.create<Op>(loc, argTy, operands); 121 op->setAttr(Op::getOperandSegmentSizeAttr(), 122 builder.getI32VectorAttr(operandSegments)); 123 return op; 124 } 125 126 static void genAsyncClause(Fortran::lower::AbstractConverter &converter, 127 const Fortran::parser::AccClause::Async *asyncClause, 128 mlir::Value &async, bool &addAsyncAttr, 129 Fortran::lower::StatementContext &stmtCtx) { 130 const auto &asyncClauseValue = asyncClause->v; 131 if (asyncClauseValue) { // async has a value. 132 async = fir::getBase(converter.genExprValue( 133 *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)); 134 } else { 135 addAsyncAttr = true; 136 } 137 } 138 139 static void genDeviceTypeClause( 140 Fortran::lower::AbstractConverter &converter, 141 const Fortran::parser::AccClause::DeviceType *deviceTypeClause, 142 SmallVectorImpl<mlir::Value> &operands, 143 Fortran::lower::StatementContext &stmtCtx) { 144 const auto &deviceTypeValue = deviceTypeClause->v; 145 if (deviceTypeValue) { 146 for (const auto &scalarIntExpr : *deviceTypeValue) { 147 mlir::Value expr = fir::getBase(converter.genExprValue( 148 *Fortran::semantics::GetExpr(scalarIntExpr), stmtCtx)); 149 operands.push_back(expr); 150 } 151 } else { 152 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 153 // * was passed as value and will be represented as a special constant. 154 mlir::Value star = firOpBuilder.createIntegerConstant( 155 converter.getCurrentLocation(), firOpBuilder.getIndexType(), starCst); 156 operands.push_back(star); 157 } 158 } 159 160 static void genIfClause(Fortran::lower::AbstractConverter &converter, 161 const Fortran::parser::AccClause::If *ifClause, 162 mlir::Value &ifCond, 163 Fortran::lower::StatementContext &stmtCtx) { 164 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 165 Value cond = fir::getBase(converter.genExprValue( 166 *Fortran::semantics::GetExpr(ifClause->v), stmtCtx)); 167 ifCond = firOpBuilder.createConvert(converter.getCurrentLocation(), 168 firOpBuilder.getI1Type(), cond); 169 } 170 171 static void genWaitClause(Fortran::lower::AbstractConverter &converter, 172 const Fortran::parser::AccClause::Wait *waitClause, 173 SmallVectorImpl<mlir::Value> &operands, 174 mlir::Value &waitDevnum, bool &addWaitAttr, 175 Fortran::lower::StatementContext &stmtCtx) { 176 const auto &waitClauseValue = waitClause->v; 177 if (waitClauseValue) { // wait has a value. 178 const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue; 179 const std::list<Fortran::parser::ScalarIntExpr> &waitList = 180 std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t); 181 for (const Fortran::parser::ScalarIntExpr &value : waitList) { 182 mlir::Value v = fir::getBase( 183 converter.genExprValue(*Fortran::semantics::GetExpr(value), stmtCtx)); 184 operands.push_back(v); 185 } 186 187 const std::optional<Fortran::parser::ScalarIntExpr> &waitDevnumValue = 188 std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t); 189 if (waitDevnumValue) 190 waitDevnum = fir::getBase(converter.genExprValue( 191 *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx)); 192 } else { 193 addWaitAttr = true; 194 } 195 } 196 197 static void genACC(Fortran::lower::AbstractConverter &converter, 198 Fortran::lower::pft::Evaluation &eval, 199 const Fortran::parser::OpenACCLoopConstruct &loopConstruct) { 200 Fortran::lower::StatementContext stmtCtx; 201 const auto &beginLoopDirective = 202 std::get<Fortran::parser::AccBeginLoopDirective>(loopConstruct.t); 203 const auto &loopDirective = 204 std::get<Fortran::parser::AccLoopDirective>(beginLoopDirective.t); 205 206 if (loopDirective.v == llvm::acc::ACCD_loop) { 207 auto &firOpBuilder = converter.getFirOpBuilder(); 208 auto currentLocation = converter.getCurrentLocation(); 209 210 // Add attribute extracted from clauses. 211 const auto &accClauseList = 212 std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t); 213 214 mlir::Value workerNum; 215 mlir::Value vectorLength; 216 mlir::Value gangNum; 217 mlir::Value gangStatic; 218 SmallVector<Value, 2> tileOperands, privateOperands, reductionOperands; 219 std::int64_t executionMapping = mlir::acc::OpenACCExecMapping::NONE; 220 221 // Lower clauses values mapped to operands. 222 for (const auto &clause : accClauseList.v) { 223 if (const auto *gangClause = 224 std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) { 225 if (gangClause->v) { 226 const Fortran::parser::AccGangArgument &x = *gangClause->v; 227 if (const auto &gangNumValue = 228 std::get<std::optional<Fortran::parser::ScalarIntExpr>>( 229 x.t)) { 230 gangNum = fir::getBase(converter.genExprValue( 231 *Fortran::semantics::GetExpr(gangNumValue.value()), stmtCtx)); 232 } 233 if (const auto &gangStaticValue = 234 std::get<std::optional<Fortran::parser::AccSizeExpr>>(x.t)) { 235 const auto &expr = 236 std::get<std::optional<Fortran::parser::ScalarIntExpr>>( 237 gangStaticValue.value().t); 238 if (expr) { 239 gangStatic = fir::getBase(converter.genExprValue( 240 *Fortran::semantics::GetExpr(*expr), stmtCtx)); 241 } else { 242 // * was passed as value and will be represented as a -1 constant 243 // integer. 244 gangStatic = firOpBuilder.createIntegerConstant( 245 currentLocation, firOpBuilder.getIntegerType(32), 246 /* STAR */ -1); 247 } 248 } 249 } 250 executionMapping |= mlir::acc::OpenACCExecMapping::GANG; 251 } else if (const auto *workerClause = 252 std::get_if<Fortran::parser::AccClause::Worker>( 253 &clause.u)) { 254 if (workerClause->v) { 255 workerNum = fir::getBase(converter.genExprValue( 256 *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx)); 257 } 258 executionMapping |= mlir::acc::OpenACCExecMapping::WORKER; 259 } else if (const auto *vectorClause = 260 std::get_if<Fortran::parser::AccClause::Vector>( 261 &clause.u)) { 262 if (vectorClause->v) { 263 vectorLength = fir::getBase(converter.genExprValue( 264 *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx)); 265 } 266 executionMapping |= mlir::acc::OpenACCExecMapping::VECTOR; 267 } else if (const auto *tileClause = 268 std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) { 269 const Fortran::parser::AccTileExprList &accTileExprList = tileClause->v; 270 for (const auto &accTileExpr : accTileExprList.v) { 271 const auto &expr = 272 std::get<std::optional<Fortran::parser::ScalarIntConstantExpr>>( 273 accTileExpr.t); 274 if (expr) { 275 tileOperands.push_back(fir::getBase(converter.genExprValue( 276 *Fortran::semantics::GetExpr(*expr), stmtCtx))); 277 } else { 278 // * was passed as value and will be represented as a -1 constant 279 // integer. 280 mlir::Value tileStar = firOpBuilder.createIntegerConstant( 281 currentLocation, firOpBuilder.getIntegerType(32), 282 /* STAR */ -1); 283 tileOperands.push_back(tileStar); 284 } 285 } 286 } else if (const auto *privateClause = 287 std::get_if<Fortran::parser::AccClause::Private>( 288 &clause.u)) { 289 genObjectList(privateClause->v, converter, privateOperands); 290 } 291 // Reduction clause is left out for the moment as the clause will probably 292 // end up having its own operation. 293 } 294 295 // Prepare the operand segement size attribute and the operands value range. 296 SmallVector<Value, 8> operands; 297 SmallVector<int32_t, 8> operandSegments; 298 addOperand(operands, operandSegments, gangNum); 299 addOperand(operands, operandSegments, gangStatic); 300 addOperand(operands, operandSegments, workerNum); 301 addOperand(operands, operandSegments, vectorLength); 302 addOperands(operands, operandSegments, tileOperands); 303 addOperands(operands, operandSegments, privateOperands); 304 addOperands(operands, operandSegments, reductionOperands); 305 306 auto loopOp = createRegionOp<mlir::acc::LoopOp, mlir::acc::YieldOp>( 307 firOpBuilder, currentLocation, operands, operandSegments); 308 309 loopOp->setAttr(mlir::acc::LoopOp::getExecutionMappingAttrName(), 310 firOpBuilder.getI64IntegerAttr(executionMapping)); 311 312 // Lower clauses mapped to attributes 313 for (const auto &clause : accClauseList.v) { 314 if (const auto *collapseClause = 315 std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) { 316 const auto *expr = Fortran::semantics::GetExpr(collapseClause->v); 317 const auto collapseValue = Fortran::evaluate::ToInt64(*expr); 318 if (collapseValue) { 319 loopOp->setAttr(mlir::acc::LoopOp::getCollapseAttrName(), 320 firOpBuilder.getI64IntegerAttr(*collapseValue)); 321 } 322 } else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) { 323 loopOp->setAttr(mlir::acc::LoopOp::getSeqAttrName(), 324 firOpBuilder.getUnitAttr()); 325 } else if (std::get_if<Fortran::parser::AccClause::Independent>( 326 &clause.u)) { 327 loopOp->setAttr(mlir::acc::LoopOp::getIndependentAttrName(), 328 firOpBuilder.getUnitAttr()); 329 } else if (std::get_if<Fortran::parser::AccClause::Auto>(&clause.u)) { 330 loopOp->setAttr(mlir::acc::LoopOp::getAutoAttrName(), 331 firOpBuilder.getUnitAttr()); 332 } 333 } 334 } 335 } 336 337 static void 338 genACCParallelOp(Fortran::lower::AbstractConverter &converter, 339 const Fortran::parser::AccClauseList &accClauseList) { 340 mlir::Value async; 341 mlir::Value numGangs; 342 mlir::Value numWorkers; 343 mlir::Value vectorLength; 344 mlir::Value ifCond; 345 mlir::Value selfCond; 346 SmallVector<Value, 2> waitOperands, reductionOperands, copyOperands, 347 copyinOperands, copyinReadonlyOperands, copyoutOperands, 348 copyoutZeroOperands, createOperands, createZeroOperands, noCreateOperands, 349 presentOperands, devicePtrOperands, attachOperands, privateOperands, 350 firstprivateOperands; 351 352 // Async, wait and self clause have optional values but can be present with 353 // no value as well. When there is no value, the op has an attribute to 354 // represent the clause. 355 bool addAsyncAttr = false; 356 bool addWaitAttr = false; 357 bool addSelfAttr = false; 358 359 auto &firOpBuilder = converter.getFirOpBuilder(); 360 auto currentLocation = converter.getCurrentLocation(); 361 Fortran::lower::StatementContext stmtCtx; 362 363 // Lower clauses values mapped to operands. 364 // Keep track of each group of operands separatly as clauses can appear 365 // more than once. 366 for (const auto &clause : accClauseList.v) { 367 if (const auto *asyncClause = 368 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) { 369 const auto &asyncClauseValue = asyncClause->v; 370 if (asyncClauseValue) { // async has a value. 371 async = fir::getBase(converter.genExprValue( 372 *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)); 373 } else { 374 addAsyncAttr = true; 375 } 376 } else if (const auto *waitClause = 377 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) { 378 const auto &waitClauseValue = waitClause->v; 379 if (waitClauseValue) { // wait has a value. 380 const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue; 381 const std::list<Fortran::parser::ScalarIntExpr> &waitList = 382 std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t); 383 for (const Fortran::parser::ScalarIntExpr &value : waitList) { 384 Value v = fir::getBase(converter.genExprValue( 385 *Fortran::semantics::GetExpr(value), stmtCtx)); 386 waitOperands.push_back(v); 387 } 388 } else { 389 addWaitAttr = true; 390 } 391 } else if (const auto *numGangsClause = 392 std::get_if<Fortran::parser::AccClause::NumGangs>( 393 &clause.u)) { 394 numGangs = fir::getBase(converter.genExprValue( 395 *Fortran::semantics::GetExpr(numGangsClause->v), stmtCtx)); 396 } else if (const auto *numWorkersClause = 397 std::get_if<Fortran::parser::AccClause::NumWorkers>( 398 &clause.u)) { 399 numWorkers = fir::getBase(converter.genExprValue( 400 *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)); 401 } else if (const auto *vectorLengthClause = 402 std::get_if<Fortran::parser::AccClause::VectorLength>( 403 &clause.u)) { 404 vectorLength = fir::getBase(converter.genExprValue( 405 *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)); 406 } else if (const auto *ifClause = 407 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) { 408 Value cond = fir::getBase(converter.genExprValue( 409 *Fortran::semantics::GetExpr(ifClause->v), stmtCtx)); 410 ifCond = firOpBuilder.createConvert(currentLocation, 411 firOpBuilder.getI1Type(), cond); 412 } else if (const auto *selfClause = 413 std::get_if<Fortran::parser::AccClause::Self>(&clause.u)) { 414 const Fortran::parser::AccSelfClause &accSelfClause = selfClause->v; 415 if (const auto *optCondition = 416 std::get_if<std::optional<Fortran::parser::ScalarLogicalExpr>>( 417 &accSelfClause.u)) { 418 if (*optCondition) { 419 Value cond = fir::getBase(converter.genExprValue( 420 *Fortran::semantics::GetExpr(*optCondition), stmtCtx)); 421 selfCond = firOpBuilder.createConvert(currentLocation, 422 firOpBuilder.getI1Type(), cond); 423 } else { 424 addSelfAttr = true; 425 } 426 } 427 } else if (const auto *copyClause = 428 std::get_if<Fortran::parser::AccClause::Copy>(&clause.u)) { 429 genObjectList(copyClause->v, converter, copyOperands); 430 } else if (const auto *copyinClause = 431 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) { 432 genObjectListWithModifier<Fortran::parser::AccClause::Copyin>( 433 copyinClause, converter, 434 Fortran::parser::AccDataModifier::Modifier::ReadOnly, 435 copyinReadonlyOperands, copyinOperands); 436 } else if (const auto *copyoutClause = 437 std::get_if<Fortran::parser::AccClause::Copyout>( 438 &clause.u)) { 439 genObjectListWithModifier<Fortran::parser::AccClause::Copyout>( 440 copyoutClause, converter, 441 Fortran::parser::AccDataModifier::Modifier::Zero, copyoutZeroOperands, 442 copyoutOperands); 443 } else if (const auto *createClause = 444 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) { 445 genObjectListWithModifier<Fortran::parser::AccClause::Create>( 446 createClause, converter, 447 Fortran::parser::AccDataModifier::Modifier::Zero, createZeroOperands, 448 createOperands); 449 } else if (const auto *noCreateClause = 450 std::get_if<Fortran::parser::AccClause::NoCreate>( 451 &clause.u)) { 452 genObjectList(noCreateClause->v, converter, noCreateOperands); 453 } else if (const auto *presentClause = 454 std::get_if<Fortran::parser::AccClause::Present>( 455 &clause.u)) { 456 genObjectList(presentClause->v, converter, presentOperands); 457 } else if (const auto *devicePtrClause = 458 std::get_if<Fortran::parser::AccClause::Deviceptr>( 459 &clause.u)) { 460 genObjectList(devicePtrClause->v, converter, devicePtrOperands); 461 } else if (const auto *attachClause = 462 std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) { 463 genObjectList(attachClause->v, converter, attachOperands); 464 } else if (const auto *privateClause = 465 std::get_if<Fortran::parser::AccClause::Private>( 466 &clause.u)) { 467 genObjectList(privateClause->v, converter, privateOperands); 468 } else if (const auto *firstprivateClause = 469 std::get_if<Fortran::parser::AccClause::Firstprivate>( 470 &clause.u)) { 471 genObjectList(firstprivateClause->v, converter, firstprivateOperands); 472 } 473 } 474 475 // Prepare the operand segement size attribute and the operands value range. 476 SmallVector<Value, 8> operands; 477 SmallVector<int32_t, 8> operandSegments; 478 addOperand(operands, operandSegments, async); 479 addOperands(operands, operandSegments, waitOperands); 480 addOperand(operands, operandSegments, numGangs); 481 addOperand(operands, operandSegments, numWorkers); 482 addOperand(operands, operandSegments, vectorLength); 483 addOperand(operands, operandSegments, ifCond); 484 addOperand(operands, operandSegments, selfCond); 485 addOperands(operands, operandSegments, reductionOperands); 486 addOperands(operands, operandSegments, copyOperands); 487 addOperands(operands, operandSegments, copyinOperands); 488 addOperands(operands, operandSegments, copyinReadonlyOperands); 489 addOperands(operands, operandSegments, copyoutOperands); 490 addOperands(operands, operandSegments, copyoutZeroOperands); 491 addOperands(operands, operandSegments, createOperands); 492 addOperands(operands, operandSegments, createZeroOperands); 493 addOperands(operands, operandSegments, noCreateOperands); 494 addOperands(operands, operandSegments, presentOperands); 495 addOperands(operands, operandSegments, devicePtrOperands); 496 addOperands(operands, operandSegments, attachOperands); 497 addOperands(operands, operandSegments, privateOperands); 498 addOperands(operands, operandSegments, firstprivateOperands); 499 500 auto parallelOp = createRegionOp<mlir::acc::ParallelOp, mlir::acc::YieldOp>( 501 firOpBuilder, currentLocation, operands, operandSegments); 502 503 if (addAsyncAttr) 504 parallelOp->setAttr(mlir::acc::ParallelOp::getAsyncAttrName(), 505 firOpBuilder.getUnitAttr()); 506 if (addWaitAttr) 507 parallelOp->setAttr(mlir::acc::ParallelOp::getWaitAttrName(), 508 firOpBuilder.getUnitAttr()); 509 if (addSelfAttr) 510 parallelOp->setAttr(mlir::acc::ParallelOp::getSelfAttrName(), 511 firOpBuilder.getUnitAttr()); 512 } 513 514 static void genACCDataOp(Fortran::lower::AbstractConverter &converter, 515 const Fortran::parser::AccClauseList &accClauseList) { 516 mlir::Value ifCond; 517 SmallVector<mlir::Value> copyOperands, copyinOperands, copyinReadonlyOperands, 518 copyoutOperands, copyoutZeroOperands, createOperands, createZeroOperands, 519 noCreateOperands, presentOperands, deviceptrOperands, attachOperands; 520 521 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 522 mlir::Location currentLocation = converter.getCurrentLocation(); 523 Fortran::lower::StatementContext stmtCtx; 524 525 // Lower clauses values mapped to operands. 526 // Keep track of each group of operands separatly as clauses can appear 527 // more than once. 528 for (const Fortran::parser::AccClause &clause : accClauseList.v) { 529 if (const auto *ifClause = 530 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) { 531 genIfClause(converter, ifClause, ifCond, stmtCtx); 532 } else if (const auto *copyClause = 533 std::get_if<Fortran::parser::AccClause::Copy>(&clause.u)) { 534 genObjectList(copyClause->v, converter, copyOperands); 535 } else if (const auto *copyinClause = 536 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) { 537 genObjectListWithModifier<Fortran::parser::AccClause::Copyin>( 538 copyinClause, converter, 539 Fortran::parser::AccDataModifier::Modifier::ReadOnly, 540 copyinReadonlyOperands, copyinOperands); 541 } else if (const auto *copyoutClause = 542 std::get_if<Fortran::parser::AccClause::Copyout>( 543 &clause.u)) { 544 genObjectListWithModifier<Fortran::parser::AccClause::Copyout>( 545 copyoutClause, converter, 546 Fortran::parser::AccDataModifier::Modifier::Zero, copyoutZeroOperands, 547 copyoutOperands); 548 } else if (const auto *createClause = 549 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) { 550 genObjectListWithModifier<Fortran::parser::AccClause::Create>( 551 createClause, converter, 552 Fortran::parser::AccDataModifier::Modifier::Zero, createZeroOperands, 553 createOperands); 554 } else if (const auto *noCreateClause = 555 std::get_if<Fortran::parser::AccClause::NoCreate>( 556 &clause.u)) { 557 genObjectList(noCreateClause->v, converter, noCreateOperands); 558 } else if (const auto *presentClause = 559 std::get_if<Fortran::parser::AccClause::Present>( 560 &clause.u)) { 561 genObjectList(presentClause->v, converter, presentOperands); 562 } else if (const auto *deviceptrClause = 563 std::get_if<Fortran::parser::AccClause::Deviceptr>( 564 &clause.u)) { 565 genObjectList(deviceptrClause->v, converter, deviceptrOperands); 566 } else if (const auto *attachClause = 567 std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) { 568 genObjectList(attachClause->v, converter, attachOperands); 569 } 570 } 571 572 // Prepare the operand segement size attribute and the operands value range. 573 SmallVector<mlir::Value> operands; 574 SmallVector<int32_t> operandSegments; 575 addOperand(operands, operandSegments, ifCond); 576 addOperands(operands, operandSegments, copyOperands); 577 addOperands(operands, operandSegments, copyinOperands); 578 addOperands(operands, operandSegments, copyinReadonlyOperands); 579 addOperands(operands, operandSegments, copyoutOperands); 580 addOperands(operands, operandSegments, copyoutZeroOperands); 581 addOperands(operands, operandSegments, createOperands); 582 addOperands(operands, operandSegments, createZeroOperands); 583 addOperands(operands, operandSegments, noCreateOperands); 584 addOperands(operands, operandSegments, presentOperands); 585 addOperands(operands, operandSegments, deviceptrOperands); 586 addOperands(operands, operandSegments, attachOperands); 587 588 createRegionOp<mlir::acc::DataOp, mlir::acc::TerminatorOp>( 589 firOpBuilder, currentLocation, operands, operandSegments); 590 } 591 592 static void 593 genACC(Fortran::lower::AbstractConverter &converter, 594 Fortran::lower::pft::Evaluation &eval, 595 const Fortran::parser::OpenACCBlockConstruct &blockConstruct) { 596 const auto &beginBlockDirective = 597 std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t); 598 const auto &blockDirective = 599 std::get<Fortran::parser::AccBlockDirective>(beginBlockDirective.t); 600 const auto &accClauseList = 601 std::get<Fortran::parser::AccClauseList>(beginBlockDirective.t); 602 603 if (blockDirective.v == llvm::acc::ACCD_parallel) { 604 genACCParallelOp(converter, accClauseList); 605 } else if (blockDirective.v == llvm::acc::ACCD_data) { 606 genACCDataOp(converter, accClauseList); 607 } 608 } 609 610 static void 611 genACCEnterDataOp(Fortran::lower::AbstractConverter &converter, 612 const Fortran::parser::AccClauseList &accClauseList) { 613 mlir::Value ifCond, async, waitDevnum; 614 SmallVector<mlir::Value> copyinOperands, createOperands, createZeroOperands, 615 attachOperands, waitOperands; 616 617 // Async, wait and self clause have optional values but can be present with 618 // no value as well. When there is no value, the op has an attribute to 619 // represent the clause. 620 bool addAsyncAttr = false; 621 bool addWaitAttr = false; 622 623 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 624 mlir::Location currentLocation = converter.getCurrentLocation(); 625 Fortran::lower::StatementContext stmtCtx; 626 627 // Lower clauses values mapped to operands. 628 // Keep track of each group of operands separatly as clauses can appear 629 // more than once. 630 for (const Fortran::parser::AccClause &clause : accClauseList.v) { 631 if (const auto *ifClause = 632 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) { 633 genIfClause(converter, ifClause, ifCond, stmtCtx); 634 } else if (const auto *asyncClause = 635 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) { 636 genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx); 637 } else if (const auto *waitClause = 638 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) { 639 genWaitClause(converter, waitClause, waitOperands, waitDevnum, 640 addWaitAttr, stmtCtx); 641 } else if (const auto *copyinClause = 642 std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) { 643 const Fortran::parser::AccObjectListWithModifier &listWithModifier = 644 copyinClause->v; 645 const Fortran::parser::AccObjectList &accObjectList = 646 std::get<Fortran::parser::AccObjectList>(listWithModifier.t); 647 genObjectList(accObjectList, converter, copyinOperands); 648 } else if (const auto *createClause = 649 std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) { 650 genObjectListWithModifier<Fortran::parser::AccClause::Create>( 651 createClause, converter, 652 Fortran::parser::AccDataModifier::Modifier::Zero, createZeroOperands, 653 createOperands); 654 } else if (const auto *attachClause = 655 std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) { 656 genObjectList(attachClause->v, converter, attachOperands); 657 } else { 658 llvm::report_fatal_error( 659 "Unknown clause in ENTER DATA directive lowering"); 660 } 661 } 662 663 // Prepare the operand segement size attribute and the operands value range. 664 SmallVector<mlir::Value, 16> operands; 665 SmallVector<int32_t, 8> operandSegments; 666 addOperand(operands, operandSegments, ifCond); 667 addOperand(operands, operandSegments, async); 668 addOperand(operands, operandSegments, waitDevnum); 669 addOperands(operands, operandSegments, waitOperands); 670 addOperands(operands, operandSegments, copyinOperands); 671 addOperands(operands, operandSegments, createOperands); 672 addOperands(operands, operandSegments, createZeroOperands); 673 addOperands(operands, operandSegments, attachOperands); 674 675 mlir::acc::EnterDataOp enterDataOp = createSimpleOp<mlir::acc::EnterDataOp>( 676 firOpBuilder, currentLocation, operands, operandSegments); 677 678 if (addAsyncAttr) 679 enterDataOp.asyncAttr(firOpBuilder.getUnitAttr()); 680 if (addWaitAttr) 681 enterDataOp.waitAttr(firOpBuilder.getUnitAttr()); 682 } 683 684 static void 685 genACCExitDataOp(Fortran::lower::AbstractConverter &converter, 686 const Fortran::parser::AccClauseList &accClauseList) { 687 mlir::Value ifCond, async, waitDevnum; 688 SmallVector<mlir::Value> copyoutOperands, deleteOperands, detachOperands, 689 waitOperands; 690 691 // Async and wait clause have optional values but can be present with 692 // no value as well. When there is no value, the op has an attribute to 693 // represent the clause. 694 bool addAsyncAttr = false; 695 bool addWaitAttr = false; 696 bool addFinalizeAttr = false; 697 698 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 699 mlir::Location currentLocation = converter.getCurrentLocation(); 700 Fortran::lower::StatementContext stmtCtx; 701 702 // Lower clauses values mapped to operands. 703 // Keep track of each group of operands separatly as clauses can appear 704 // more than once. 705 for (const Fortran::parser::AccClause &clause : accClauseList.v) { 706 if (const auto *ifClause = 707 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) { 708 genIfClause(converter, ifClause, ifCond, stmtCtx); 709 } else if (const auto *asyncClause = 710 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) { 711 genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx); 712 } else if (const auto *waitClause = 713 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) { 714 genWaitClause(converter, waitClause, waitOperands, waitDevnum, 715 addWaitAttr, stmtCtx); 716 } else if (const auto *copyoutClause = 717 std::get_if<Fortran::parser::AccClause::Copyout>( 718 &clause.u)) { 719 const Fortran::parser::AccObjectListWithModifier &listWithModifier = 720 copyoutClause->v; 721 const Fortran::parser::AccObjectList &accObjectList = 722 std::get<Fortran::parser::AccObjectList>(listWithModifier.t); 723 genObjectList(accObjectList, converter, copyoutOperands); 724 } else if (const auto *deleteClause = 725 std::get_if<Fortran::parser::AccClause::Delete>(&clause.u)) { 726 genObjectList(deleteClause->v, converter, deleteOperands); 727 } else if (const auto *detachClause = 728 std::get_if<Fortran::parser::AccClause::Detach>(&clause.u)) { 729 genObjectList(detachClause->v, converter, detachOperands); 730 } else if (std::get_if<Fortran::parser::AccClause::Finalize>(&clause.u)) { 731 addFinalizeAttr = true; 732 } 733 } 734 735 // Prepare the operand segement size attribute and the operands value range. 736 SmallVector<mlir::Value, 14> operands; 737 SmallVector<int32_t, 7> operandSegments; 738 addOperand(operands, operandSegments, ifCond); 739 addOperand(operands, operandSegments, async); 740 addOperand(operands, operandSegments, waitDevnum); 741 addOperands(operands, operandSegments, waitOperands); 742 addOperands(operands, operandSegments, copyoutOperands); 743 addOperands(operands, operandSegments, deleteOperands); 744 addOperands(operands, operandSegments, detachOperands); 745 746 mlir::acc::ExitDataOp exitDataOp = createSimpleOp<mlir::acc::ExitDataOp>( 747 firOpBuilder, currentLocation, operands, operandSegments); 748 749 if (addAsyncAttr) 750 exitDataOp.asyncAttr(firOpBuilder.getUnitAttr()); 751 if (addWaitAttr) 752 exitDataOp.waitAttr(firOpBuilder.getUnitAttr()); 753 if (addFinalizeAttr) 754 exitDataOp.finalizeAttr(firOpBuilder.getUnitAttr()); 755 } 756 757 template <typename Op> 758 static void 759 genACCInitShutdownOp(Fortran::lower::AbstractConverter &converter, 760 const Fortran::parser::AccClauseList &accClauseList) { 761 mlir::Value ifCond, deviceNum; 762 SmallVector<mlir::Value> deviceTypeOperands; 763 764 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 765 mlir::Location currentLocation = converter.getCurrentLocation(); 766 Fortran::lower::StatementContext stmtCtx; 767 768 // Lower clauses values mapped to operands. 769 // Keep track of each group of operands separatly as clauses can appear 770 // more than once. 771 for (const Fortran::parser::AccClause &clause : accClauseList.v) { 772 if (const auto *ifClause = 773 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) { 774 genIfClause(converter, ifClause, ifCond, stmtCtx); 775 } else if (const auto *deviceNumClause = 776 std::get_if<Fortran::parser::AccClause::DeviceNum>( 777 &clause.u)) { 778 deviceNum = fir::getBase(converter.genExprValue( 779 *Fortran::semantics::GetExpr(deviceNumClause->v), stmtCtx)); 780 } else if (const auto *deviceTypeClause = 781 std::get_if<Fortran::parser::AccClause::DeviceType>( 782 &clause.u)) { 783 genDeviceTypeClause(converter, deviceTypeClause, deviceTypeOperands, 784 stmtCtx); 785 } 786 } 787 788 // Prepare the operand segement size attribute and the operands value range. 789 SmallVector<mlir::Value, 6> operands; 790 SmallVector<int32_t, 3> operandSegments; 791 addOperands(operands, operandSegments, deviceTypeOperands); 792 addOperand(operands, operandSegments, deviceNum); 793 addOperand(operands, operandSegments, ifCond); 794 795 createSimpleOp<Op>(firOpBuilder, currentLocation, operands, operandSegments); 796 } 797 798 static void 799 genACCUpdateOp(Fortran::lower::AbstractConverter &converter, 800 const Fortran::parser::AccClauseList &accClauseList) { 801 mlir::Value ifCond, async, waitDevnum; 802 SmallVector<mlir::Value> hostOperands, deviceOperands, waitOperands, 803 deviceTypeOperands; 804 805 // Async and wait clause have optional values but can be present with 806 // no value as well. When there is no value, the op has an attribute to 807 // represent the clause. 808 bool addAsyncAttr = false; 809 bool addWaitAttr = false; 810 bool addIfPresentAttr = false; 811 812 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 813 mlir::Location currentLocation = converter.getCurrentLocation(); 814 Fortran::lower::StatementContext stmtCtx; 815 816 // Lower clauses values mapped to operands. 817 // Keep track of each group of operands separatly as clauses can appear 818 // more than once. 819 for (const Fortran::parser::AccClause &clause : accClauseList.v) { 820 if (const auto *ifClause = 821 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) { 822 genIfClause(converter, ifClause, ifCond, stmtCtx); 823 } else if (const auto *asyncClause = 824 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) { 825 genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx); 826 } else if (const auto *waitClause = 827 std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) { 828 genWaitClause(converter, waitClause, waitOperands, waitDevnum, 829 addWaitAttr, stmtCtx); 830 } else if (const auto *deviceTypeClause = 831 std::get_if<Fortran::parser::AccClause::DeviceType>( 832 &clause.u)) { 833 genDeviceTypeClause(converter, deviceTypeClause, deviceTypeOperands, 834 stmtCtx); 835 } else if (const auto *hostClause = 836 std::get_if<Fortran::parser::AccClause::Host>(&clause.u)) { 837 genObjectList(hostClause->v, converter, hostOperands); 838 } else if (const auto *deviceClause = 839 std::get_if<Fortran::parser::AccClause::Device>(&clause.u)) { 840 genObjectList(deviceClause->v, converter, deviceOperands); 841 } 842 } 843 844 // Prepare the operand segement size attribute and the operands value range. 845 SmallVector<mlir::Value> operands; 846 SmallVector<int32_t> operandSegments; 847 addOperand(operands, operandSegments, ifCond); 848 addOperand(operands, operandSegments, async); 849 addOperand(operands, operandSegments, waitDevnum); 850 addOperands(operands, operandSegments, waitOperands); 851 addOperands(operands, operandSegments, deviceTypeOperands); 852 addOperands(operands, operandSegments, hostOperands); 853 addOperands(operands, operandSegments, deviceOperands); 854 855 mlir::acc::UpdateOp updateOp = createSimpleOp<mlir::acc::UpdateOp>( 856 firOpBuilder, currentLocation, operands, operandSegments); 857 858 if (addAsyncAttr) 859 updateOp.asyncAttr(firOpBuilder.getUnitAttr()); 860 if (addWaitAttr) 861 updateOp.waitAttr(firOpBuilder.getUnitAttr()); 862 if (addIfPresentAttr) 863 updateOp.ifPresentAttr(firOpBuilder.getUnitAttr()); 864 } 865 866 static void 867 genACC(Fortran::lower::AbstractConverter &converter, 868 Fortran::lower::pft::Evaluation &eval, 869 const Fortran::parser::OpenACCStandaloneConstruct &standaloneConstruct) { 870 const auto &standaloneDirective = 871 std::get<Fortran::parser::AccStandaloneDirective>(standaloneConstruct.t); 872 const auto &accClauseList = 873 std::get<Fortran::parser::AccClauseList>(standaloneConstruct.t); 874 875 if (standaloneDirective.v == llvm::acc::Directive::ACCD_enter_data) { 876 genACCEnterDataOp(converter, accClauseList); 877 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_exit_data) { 878 genACCExitDataOp(converter, accClauseList); 879 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_init) { 880 genACCInitShutdownOp<mlir::acc::InitOp>(converter, accClauseList); 881 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_shutdown) { 882 genACCInitShutdownOp<mlir::acc::ShutdownOp>(converter, accClauseList); 883 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_set) { 884 TODO(converter.getCurrentLocation(), 885 "OpenACC set directive not lowered yet!"); 886 } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_update) { 887 genACCUpdateOp(converter, accClauseList); 888 } 889 } 890 891 static void genACC(Fortran::lower::AbstractConverter &converter, 892 Fortran::lower::pft::Evaluation &eval, 893 const Fortran::parser::OpenACCWaitConstruct &waitConstruct) { 894 895 const auto &waitArgument = 896 std::get<std::optional<Fortran::parser::AccWaitArgument>>( 897 waitConstruct.t); 898 const auto &accClauseList = 899 std::get<Fortran::parser::AccClauseList>(waitConstruct.t); 900 901 mlir::Value ifCond, waitDevnum, async; 902 SmallVector<mlir::Value> waitOperands; 903 904 // Async clause have optional values but can be present with 905 // no value as well. When there is no value, the op has an attribute to 906 // represent the clause. 907 bool addAsyncAttr = false; 908 909 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 910 mlir::Location currentLocation = converter.getCurrentLocation(); 911 Fortran::lower::StatementContext stmtCtx; 912 913 if (waitArgument) { // wait has a value. 914 const Fortran::parser::AccWaitArgument &waitArg = *waitArgument; 915 const std::list<Fortran::parser::ScalarIntExpr> &waitList = 916 std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t); 917 for (const Fortran::parser::ScalarIntExpr &value : waitList) { 918 mlir::Value v = fir::getBase( 919 converter.genExprValue(*Fortran::semantics::GetExpr(value), stmtCtx)); 920 waitOperands.push_back(v); 921 } 922 923 const std::optional<Fortran::parser::ScalarIntExpr> &waitDevnumValue = 924 std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t); 925 if (waitDevnumValue) 926 waitDevnum = fir::getBase(converter.genExprValue( 927 *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx)); 928 } 929 930 // Lower clauses values mapped to operands. 931 // Keep track of each group of operands separatly as clauses can appear 932 // more than once. 933 for (const Fortran::parser::AccClause &clause : accClauseList.v) { 934 if (const auto *ifClause = 935 std::get_if<Fortran::parser::AccClause::If>(&clause.u)) { 936 genIfClause(converter, ifClause, ifCond, stmtCtx); 937 } else if (const auto *asyncClause = 938 std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) { 939 genAsyncClause(converter, asyncClause, async, addAsyncAttr, stmtCtx); 940 } 941 } 942 943 // Prepare the operand segement size attribute and the operands value range. 944 SmallVector<mlir::Value> operands; 945 SmallVector<int32_t> operandSegments; 946 addOperands(operands, operandSegments, waitOperands); 947 addOperand(operands, operandSegments, async); 948 addOperand(operands, operandSegments, waitDevnum); 949 addOperand(operands, operandSegments, ifCond); 950 951 mlir::acc::WaitOp waitOp = createSimpleOp<mlir::acc::WaitOp>( 952 firOpBuilder, currentLocation, operands, operandSegments); 953 954 if (addAsyncAttr) 955 waitOp.asyncAttr(firOpBuilder.getUnitAttr()); 956 } 957 958 void Fortran::lower::genOpenACCConstruct( 959 Fortran::lower::AbstractConverter &converter, 960 Fortran::lower::pft::Evaluation &eval, 961 const Fortran::parser::OpenACCConstruct &accConstruct) { 962 963 std::visit( 964 common::visitors{ 965 [&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) { 966 genACC(converter, eval, blockConstruct); 967 }, 968 [&](const Fortran::parser::OpenACCCombinedConstruct 969 &combinedConstruct) { 970 TODO(converter.getCurrentLocation(), 971 "OpenACC Combined construct not lowered yet!"); 972 }, 973 [&](const Fortran::parser::OpenACCLoopConstruct &loopConstruct) { 974 genACC(converter, eval, loopConstruct); 975 }, 976 [&](const Fortran::parser::OpenACCStandaloneConstruct 977 &standaloneConstruct) { 978 genACC(converter, eval, standaloneConstruct); 979 }, 980 [&](const Fortran::parser::OpenACCRoutineConstruct 981 &routineConstruct) { 982 TODO(converter.getCurrentLocation(), 983 "OpenACC Routine construct not lowered yet!"); 984 }, 985 [&](const Fortran::parser::OpenACCCacheConstruct &cacheConstruct) { 986 TODO(converter.getCurrentLocation(), 987 "OpenACC Cache construct not lowered yet!"); 988 }, 989 [&](const Fortran::parser::OpenACCWaitConstruct &waitConstruct) { 990 genACC(converter, eval, waitConstruct); 991 }, 992 [&](const Fortran::parser::OpenACCAtomicConstruct &atomicConstruct) { 993 TODO(converter.getCurrentLocation(), 994 "OpenACC Atomic construct not lowered yet!"); 995 }, 996 }, 997 accConstruct.u); 998 } 999