1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 // ============================================================================= 8 9 #include "mlir/Dialect/OpenACC/OpenACC.h" 10 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc" 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "mlir/IR/DialectImplementation.h" 14 #include "mlir/IR/Matchers.h" 15 #include "mlir/IR/OpImplementation.h" 16 #include "mlir/Transforms/DialectConversion.h" 17 #include "llvm/ADT/TypeSwitch.h" 18 19 using namespace mlir; 20 using namespace acc; 21 22 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc" 23 24 //===----------------------------------------------------------------------===// 25 // OpenACC operations 26 //===----------------------------------------------------------------------===// 27 28 void OpenACCDialect::initialize() { 29 addOperations< 30 #define GET_OP_LIST 31 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc" 32 >(); 33 addAttributes< 34 #define GET_ATTRDEF_LIST 35 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc" 36 >(); 37 } 38 39 template <typename StructureOp> 40 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, 41 unsigned nRegions = 1) { 42 43 SmallVector<Region *, 2> regions; 44 for (unsigned i = 0; i < nRegions; ++i) 45 regions.push_back(state.addRegion()); 46 47 for (Region *region : regions) { 48 if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{})) 49 return failure(); 50 } 51 52 return success(); 53 } 54 55 static ParseResult 56 parseOperandList(OpAsmParser &parser, StringRef keyword, 57 SmallVectorImpl<OpAsmParser::OperandType> &args, 58 SmallVectorImpl<Type> &argTypes, OperationState &result) { 59 if (failed(parser.parseOptionalKeyword(keyword))) 60 return success(); 61 62 if (failed(parser.parseLParen())) 63 return failure(); 64 65 // Exit early if the list is empty. 66 if (succeeded(parser.parseOptionalRParen())) 67 return success(); 68 69 do { 70 OpAsmParser::OperandType arg; 71 Type type; 72 73 if (parser.parseRegionArgument(arg) || parser.parseColonType(type)) 74 return failure(); 75 76 args.push_back(arg); 77 argTypes.push_back(type); 78 } while (succeeded(parser.parseOptionalComma())); 79 80 if (failed(parser.parseRParen())) 81 return failure(); 82 83 return parser.resolveOperands(args, argTypes, parser.getCurrentLocation(), 84 result.operands); 85 } 86 87 static void printOperandList(Operation::operand_range operands, 88 StringRef listName, OpAsmPrinter &printer) { 89 90 if (!operands.empty()) { 91 printer << " " << listName << "("; 92 llvm::interleaveComma(operands, printer, [&](Value op) { 93 printer << op << ": " << op.getType(); 94 }); 95 printer << ")"; 96 } 97 } 98 99 static ParseResult parseOptionalOperand(OpAsmParser &parser, StringRef keyword, 100 OpAsmParser::OperandType &operand, 101 Type type, bool &hasOptional, 102 OperationState &result) { 103 hasOptional = false; 104 if (succeeded(parser.parseOptionalKeyword(keyword))) { 105 hasOptional = true; 106 if (parser.parseLParen() || parser.parseOperand(operand) || 107 parser.resolveOperand(operand, type, result.operands) || 108 parser.parseRParen()) 109 return failure(); 110 } 111 return success(); 112 } 113 114 static ParseResult parseOperandAndType(OpAsmParser &parser, 115 OperationState &result) { 116 OpAsmParser::OperandType operand; 117 Type type; 118 if (parser.parseOperand(operand) || parser.parseColonType(type) || 119 parser.resolveOperand(operand, type, result.operands)) 120 return failure(); 121 return success(); 122 } 123 124 /// Parse optional operand and its type wrapped in parenthesis prefixed with 125 /// a keyword. 126 /// Example: 127 /// keyword `(` %vectorLength: i64 `)` 128 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser, 129 StringRef keyword, 130 OperationState &result) { 131 OpAsmParser::OperandType operand; 132 if (succeeded(parser.parseOptionalKeyword(keyword))) { 133 return failure(parser.parseLParen() || 134 parseOperandAndType(parser, result) || parser.parseRParen()); 135 } 136 return llvm::None; 137 } 138 139 /// Parse optional operand and its type wrapped in parenthesis. 140 /// Example: 141 /// `(` %vectorLength: i64 `)` 142 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser, 143 OperationState &result) { 144 if (succeeded(parser.parseOptionalLParen())) { 145 return failure(parseOperandAndType(parser, result) || parser.parseRParen()); 146 } 147 return llvm::None; 148 } 149 150 /// Parse optional operand with its type prefixed with prefixKeyword `=`. 151 /// Example: 152 /// num=%gangNum: i32 153 static OptionalParseResult parserOptionalOperandAndTypeWithPrefix( 154 OpAsmParser &parser, OperationState &result, StringRef prefixKeyword) { 155 if (succeeded(parser.parseOptionalKeyword(prefixKeyword))) { 156 parser.parseEqual(); 157 return parseOperandAndType(parser, result); 158 } 159 return llvm::None; 160 } 161 162 static bool isComputeOperation(Operation *op) { 163 return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op); 164 } 165 166 namespace { 167 /// Pattern to remove operation without region that have constant false `ifCond` 168 /// and remove the condition from the operation if the `ifCond` is a true 169 /// constant. 170 template <typename OpTy> 171 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> { 172 using OpRewritePattern<OpTy>::OpRewritePattern; 173 174 LogicalResult matchAndRewrite(OpTy op, 175 PatternRewriter &rewriter) const override { 176 // Early return if there is no condition. 177 Value ifCond = op.ifCond(); 178 if (!ifCond) 179 return success(); 180 181 IntegerAttr constAttr; 182 if (matchPattern(ifCond, m_Constant(&constAttr))) { 183 if (constAttr.getInt()) 184 rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); }); 185 else 186 rewriter.eraseOp(op); 187 } 188 189 return success(); 190 } 191 }; 192 } // namespace 193 194 //===----------------------------------------------------------------------===// 195 // ParallelOp 196 //===----------------------------------------------------------------------===// 197 198 /// Parse acc.parallel operation 199 /// operation := `acc.parallel` `async` `(` index `)`? 200 /// `wait` `(` index-list `)`? 201 /// `num_gangs` `(` value `)`? 202 /// `num_workers` `(` value `)`? 203 /// `vector_length` `(` value `)`? 204 /// `if` `(` value `)`? 205 /// `self` `(` value `)`? 206 /// `reduction` `(` value-list `)`? 207 /// `copy` `(` value-list `)`? 208 /// `copyin` `(` value-list `)`? 209 /// `copyin_readonly` `(` value-list `)`? 210 /// `copyout` `(` value-list `)`? 211 /// `copyout_zero` `(` value-list `)`? 212 /// `create` `(` value-list `)`? 213 /// `create_zero` `(` value-list `)`? 214 /// `no_create` `(` value-list `)`? 215 /// `present` `(` value-list `)`? 216 /// `deviceptr` `(` value-list `)`? 217 /// `attach` `(` value-list `)`? 218 /// `private` `(` value-list `)`? 219 /// `firstprivate` `(` value-list `)`? 220 /// region attr-dict? 221 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) { 222 Builder &builder = parser.getBuilder(); 223 SmallVector<OpAsmParser::OperandType, 8> privateOperands, 224 firstprivateOperands, copyOperands, copyinOperands, 225 copyinReadonlyOperands, copyoutOperands, copyoutZeroOperands, 226 createOperands, createZeroOperands, noCreateOperands, presentOperands, 227 devicePtrOperands, attachOperands, waitOperands, reductionOperands; 228 SmallVector<Type, 8> waitOperandTypes, reductionOperandTypes, 229 copyOperandTypes, copyinOperandTypes, copyinReadonlyOperandTypes, 230 copyoutOperandTypes, copyoutZeroOperandTypes, createOperandTypes, 231 createZeroOperandTypes, noCreateOperandTypes, presentOperandTypes, 232 deviceptrOperandTypes, attachOperandTypes, privateOperandTypes, 233 firstprivateOperandTypes; 234 235 SmallVector<Type, 8> operandTypes; 236 OpAsmParser::OperandType ifCond, selfCond; 237 bool hasIfCond = false, hasSelfCond = false; 238 OptionalParseResult async, numGangs, numWorkers, vectorLength; 239 Type i1Type = builder.getI1Type(); 240 241 // async()? 242 async = parseOptionalOperandAndType(parser, ParallelOp::getAsyncKeyword(), 243 result); 244 if (async.hasValue() && failed(*async)) 245 return failure(); 246 247 // wait()? 248 if (failed(parseOperandList(parser, ParallelOp::getWaitKeyword(), 249 waitOperands, waitOperandTypes, result))) 250 return failure(); 251 252 // num_gangs(value)? 253 numGangs = parseOptionalOperandAndType( 254 parser, ParallelOp::getNumGangsKeyword(), result); 255 if (numGangs.hasValue() && failed(*numGangs)) 256 return failure(); 257 258 // num_workers(value)? 259 numWorkers = parseOptionalOperandAndType( 260 parser, ParallelOp::getNumWorkersKeyword(), result); 261 if (numWorkers.hasValue() && failed(*numWorkers)) 262 return failure(); 263 264 // vector_length(value)? 265 vectorLength = parseOptionalOperandAndType( 266 parser, ParallelOp::getVectorLengthKeyword(), result); 267 if (vectorLength.hasValue() && failed(*vectorLength)) 268 return failure(); 269 270 // if()? 271 if (failed(parseOptionalOperand(parser, ParallelOp::getIfKeyword(), ifCond, 272 i1Type, hasIfCond, result))) 273 return failure(); 274 275 // self()? 276 if (failed(parseOptionalOperand(parser, ParallelOp::getSelfKeyword(), 277 selfCond, i1Type, hasSelfCond, result))) 278 return failure(); 279 280 // reduction()? 281 if (failed(parseOperandList(parser, ParallelOp::getReductionKeyword(), 282 reductionOperands, reductionOperandTypes, 283 result))) 284 return failure(); 285 286 // copy()? 287 if (failed(parseOperandList(parser, ParallelOp::getCopyKeyword(), 288 copyOperands, copyOperandTypes, result))) 289 return failure(); 290 291 // copyin()? 292 if (failed(parseOperandList(parser, ParallelOp::getCopyinKeyword(), 293 copyinOperands, copyinOperandTypes, result))) 294 return failure(); 295 296 // copyin_readonly()? 297 if (failed(parseOperandList(parser, ParallelOp::getCopyinReadonlyKeyword(), 298 copyinReadonlyOperands, 299 copyinReadonlyOperandTypes, result))) 300 return failure(); 301 302 // copyout()? 303 if (failed(parseOperandList(parser, ParallelOp::getCopyoutKeyword(), 304 copyoutOperands, copyoutOperandTypes, result))) 305 return failure(); 306 307 // copyout_zero()? 308 if (failed(parseOperandList(parser, ParallelOp::getCopyoutZeroKeyword(), 309 copyoutZeroOperands, copyoutZeroOperandTypes, 310 result))) 311 return failure(); 312 313 // create()? 314 if (failed(parseOperandList(parser, ParallelOp::getCreateKeyword(), 315 createOperands, createOperandTypes, result))) 316 return failure(); 317 318 // create_zero()? 319 if (failed(parseOperandList(parser, ParallelOp::getCreateZeroKeyword(), 320 createZeroOperands, createZeroOperandTypes, 321 result))) 322 return failure(); 323 324 // no_create()? 325 if (failed(parseOperandList(parser, ParallelOp::getNoCreateKeyword(), 326 noCreateOperands, noCreateOperandTypes, result))) 327 return failure(); 328 329 // present()? 330 if (failed(parseOperandList(parser, ParallelOp::getPresentKeyword(), 331 presentOperands, presentOperandTypes, result))) 332 return failure(); 333 334 // deviceptr()? 335 if (failed(parseOperandList(parser, ParallelOp::getDevicePtrKeyword(), 336 devicePtrOperands, deviceptrOperandTypes, 337 result))) 338 return failure(); 339 340 // attach()? 341 if (failed(parseOperandList(parser, ParallelOp::getAttachKeyword(), 342 attachOperands, attachOperandTypes, result))) 343 return failure(); 344 345 // private()? 346 if (failed(parseOperandList(parser, ParallelOp::getPrivateKeyword(), 347 privateOperands, privateOperandTypes, result))) 348 return failure(); 349 350 // firstprivate()? 351 if (failed(parseOperandList(parser, ParallelOp::getFirstPrivateKeyword(), 352 firstprivateOperands, firstprivateOperandTypes, 353 result))) 354 return failure(); 355 356 // Parallel op region 357 if (failed(parseRegions<ParallelOp>(parser, result))) 358 return failure(); 359 360 result.addAttribute( 361 ParallelOp::getOperandSegmentSizeAttr(), 362 builder.getI32VectorAttr( 363 {static_cast<int32_t>(async.hasValue() ? 1 : 0), 364 static_cast<int32_t>(waitOperands.size()), 365 static_cast<int32_t>(numGangs.hasValue() ? 1 : 0), 366 static_cast<int32_t>(numWorkers.hasValue() ? 1 : 0), 367 static_cast<int32_t>(vectorLength.hasValue() ? 1 : 0), 368 static_cast<int32_t>(hasIfCond ? 1 : 0), 369 static_cast<int32_t>(hasSelfCond ? 1 : 0), 370 static_cast<int32_t>(reductionOperands.size()), 371 static_cast<int32_t>(copyOperands.size()), 372 static_cast<int32_t>(copyinOperands.size()), 373 static_cast<int32_t>(copyinReadonlyOperands.size()), 374 static_cast<int32_t>(copyoutOperands.size()), 375 static_cast<int32_t>(copyoutZeroOperands.size()), 376 static_cast<int32_t>(createOperands.size()), 377 static_cast<int32_t>(createZeroOperands.size()), 378 static_cast<int32_t>(noCreateOperands.size()), 379 static_cast<int32_t>(presentOperands.size()), 380 static_cast<int32_t>(devicePtrOperands.size()), 381 static_cast<int32_t>(attachOperands.size()), 382 static_cast<int32_t>(privateOperands.size()), 383 static_cast<int32_t>(firstprivateOperands.size())})); 384 385 // Additional attributes 386 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 387 return failure(); 388 389 return success(); 390 } 391 392 void ParallelOp::print(OpAsmPrinter &printer) { 393 // async()? 394 if (Value async = this->async()) 395 printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ": " 396 << async.getType() << ")"; 397 398 // wait()? 399 printOperandList(waitOperands(), ParallelOp::getWaitKeyword(), printer); 400 401 // num_gangs()? 402 if (Value numGangs = this->numGangs()) 403 printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs 404 << ": " << numGangs.getType() << ")"; 405 406 // num_workers()? 407 if (Value numWorkers = this->numWorkers()) 408 printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers 409 << ": " << numWorkers.getType() << ")"; 410 411 // vector_length()? 412 if (Value vectorLength = this->vectorLength()) 413 printer << " " << ParallelOp::getVectorLengthKeyword() << "(" 414 << vectorLength << ": " << vectorLength.getType() << ")"; 415 416 // if()? 417 if (Value ifCond = this->ifCond()) 418 printer << " " << ParallelOp::getIfKeyword() << "(" << ifCond << ")"; 419 420 // self()? 421 if (Value selfCond = this->selfCond()) 422 printer << " " << ParallelOp::getSelfKeyword() << "(" << selfCond << ")"; 423 424 // reduction()? 425 printOperandList(reductionOperands(), ParallelOp::getReductionKeyword(), 426 printer); 427 428 // copy()? 429 printOperandList(copyOperands(), ParallelOp::getCopyKeyword(), printer); 430 431 // copyin()? 432 printOperandList(copyinOperands(), ParallelOp::getCopyinKeyword(), printer); 433 434 // copyin_readonly()? 435 printOperandList(copyinReadonlyOperands(), 436 ParallelOp::getCopyinReadonlyKeyword(), printer); 437 438 // copyout()? 439 printOperandList(copyoutOperands(), ParallelOp::getCopyoutKeyword(), printer); 440 441 // copyout_zero()? 442 printOperandList(copyoutZeroOperands(), ParallelOp::getCopyoutZeroKeyword(), 443 printer); 444 445 // create()? 446 printOperandList(createOperands(), ParallelOp::getCreateKeyword(), printer); 447 448 // create_zero()? 449 printOperandList(createZeroOperands(), ParallelOp::getCreateZeroKeyword(), 450 printer); 451 452 // no_create()? 453 printOperandList(noCreateOperands(), ParallelOp::getNoCreateKeyword(), 454 printer); 455 456 // present()? 457 printOperandList(presentOperands(), ParallelOp::getPresentKeyword(), printer); 458 459 // deviceptr()? 460 printOperandList(devicePtrOperands(), ParallelOp::getDevicePtrKeyword(), 461 printer); 462 463 // attach()? 464 printOperandList(attachOperands(), ParallelOp::getAttachKeyword(), printer); 465 466 // private()? 467 printOperandList(gangPrivateOperands(), ParallelOp::getPrivateKeyword(), 468 printer); 469 470 // firstprivate()? 471 printOperandList(gangFirstPrivateOperands(), 472 ParallelOp::getFirstPrivateKeyword(), printer); 473 474 printer << ' '; 475 printer.printRegion(region(), 476 /*printEntryBlockArgs=*/false, 477 /*printBlockTerminators=*/true); 478 printer.printOptionalAttrDictWithKeyword( 479 (*this)->getAttrs(), ParallelOp::getOperandSegmentSizeAttr()); 480 } 481 482 unsigned ParallelOp::getNumDataOperands() { 483 return reductionOperands().size() + copyOperands().size() + 484 copyinOperands().size() + copyinReadonlyOperands().size() + 485 copyoutOperands().size() + copyoutZeroOperands().size() + 486 createOperands().size() + createZeroOperands().size() + 487 noCreateOperands().size() + presentOperands().size() + 488 devicePtrOperands().size() + attachOperands().size() + 489 gangPrivateOperands().size() + gangFirstPrivateOperands().size(); 490 } 491 492 Value ParallelOp::getDataOperand(unsigned i) { 493 unsigned numOptional = async() ? 1 : 0; 494 numOptional += numGangs() ? 1 : 0; 495 numOptional += numWorkers() ? 1 : 0; 496 numOptional += vectorLength() ? 1 : 0; 497 numOptional += ifCond() ? 1 : 0; 498 numOptional += selfCond() ? 1 : 0; 499 return getOperand(waitOperands().size() + numOptional + i); 500 } 501 502 //===----------------------------------------------------------------------===// 503 // LoopOp 504 //===----------------------------------------------------------------------===// 505 506 /// Parse acc.loop operation 507 /// operation := `acc.loop` 508 /// (`gang` ( `(` (`num=` value)? (`,` `static=` value `)`)? )? )? 509 /// (`vector` ( `(` value `)` )? )? (`worker` (`(` value `)`)? )? 510 /// (`vector_length` `(` value `)`)? 511 /// (`tile` `(` value-list `)`)? 512 /// (`private` `(` value-list `)`)? 513 /// (`reduction` `(` value-list `)`)? 514 /// region attr-dict? 515 ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) { 516 Builder &builder = parser.getBuilder(); 517 unsigned executionMapping = OpenACCExecMapping::NONE; 518 SmallVector<Type, 8> operandTypes; 519 SmallVector<OpAsmParser::OperandType, 8> privateOperands, reductionOperands; 520 SmallVector<OpAsmParser::OperandType, 8> tileOperands; 521 OptionalParseResult gangNum, gangStatic, worker, vector; 522 523 // gang? 524 if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangKeyword()))) 525 executionMapping |= OpenACCExecMapping::GANG; 526 527 // optional gang operand 528 if (succeeded(parser.parseOptionalLParen())) { 529 gangNum = parserOptionalOperandAndTypeWithPrefix( 530 parser, result, LoopOp::getGangNumKeyword()); 531 if (gangNum.hasValue() && failed(*gangNum)) 532 return failure(); 533 parser.parseOptionalComma(); 534 gangStatic = parserOptionalOperandAndTypeWithPrefix( 535 parser, result, LoopOp::getGangStaticKeyword()); 536 if (gangStatic.hasValue() && failed(*gangStatic)) 537 return failure(); 538 parser.parseOptionalComma(); 539 if (failed(parser.parseRParen())) 540 return failure(); 541 } 542 543 // worker? 544 if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerKeyword()))) 545 executionMapping |= OpenACCExecMapping::WORKER; 546 547 // optional worker operand 548 worker = parseOptionalOperandAndType(parser, result); 549 if (worker.hasValue() && failed(*worker)) 550 return failure(); 551 552 // vector? 553 if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorKeyword()))) 554 executionMapping |= OpenACCExecMapping::VECTOR; 555 556 // optional vector operand 557 vector = parseOptionalOperandAndType(parser, result); 558 if (vector.hasValue() && failed(*vector)) 559 return failure(); 560 561 // tile()? 562 if (failed(parseOperandList(parser, LoopOp::getTileKeyword(), tileOperands, 563 operandTypes, result))) 564 return failure(); 565 566 // private()? 567 if (failed(parseOperandList(parser, LoopOp::getPrivateKeyword(), 568 privateOperands, operandTypes, result))) 569 return failure(); 570 571 // reduction()? 572 if (failed(parseOperandList(parser, LoopOp::getReductionKeyword(), 573 reductionOperands, operandTypes, result))) 574 return failure(); 575 576 if (executionMapping != acc::OpenACCExecMapping::NONE) 577 result.addAttribute(LoopOp::getExecutionMappingAttrName(), 578 builder.getI64IntegerAttr(executionMapping)); 579 580 // Parse optional results in case there is a reduce. 581 if (parser.parseOptionalArrowTypeList(result.types)) 582 return failure(); 583 584 if (failed(parseRegions<LoopOp>(parser, result))) 585 return failure(); 586 587 result.addAttribute(LoopOp::getOperandSegmentSizeAttr(), 588 builder.getI32VectorAttr( 589 {static_cast<int32_t>(gangNum.hasValue() ? 1 : 0), 590 static_cast<int32_t>(gangStatic.hasValue() ? 1 : 0), 591 static_cast<int32_t>(worker.hasValue() ? 1 : 0), 592 static_cast<int32_t>(vector.hasValue() ? 1 : 0), 593 static_cast<int32_t>(tileOperands.size()), 594 static_cast<int32_t>(privateOperands.size()), 595 static_cast<int32_t>(reductionOperands.size())})); 596 597 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 598 return failure(); 599 600 return success(); 601 } 602 603 void LoopOp::print(OpAsmPrinter &printer) { 604 unsigned execMapping = exec_mapping(); 605 if (execMapping & OpenACCExecMapping::GANG) { 606 printer << " " << LoopOp::getGangKeyword(); 607 Value gangNum = this->gangNum(); 608 Value gangStatic = this->gangStatic(); 609 610 // Print optional gang operands 611 if (gangNum || gangStatic) { 612 printer << "("; 613 if (gangNum) { 614 printer << LoopOp::getGangNumKeyword() << "=" << gangNum << ": " 615 << gangNum.getType(); 616 if (gangStatic) 617 printer << ", "; 618 } 619 if (gangStatic) 620 printer << LoopOp::getGangStaticKeyword() << "=" << gangStatic << ": " 621 << gangStatic.getType(); 622 printer << ")"; 623 } 624 } 625 626 if (execMapping & OpenACCExecMapping::WORKER) { 627 printer << " " << LoopOp::getWorkerKeyword(); 628 629 // Print optional worker operand if present 630 if (Value workerNum = this->workerNum()) 631 printer << "(" << workerNum << ": " << workerNum.getType() << ")"; 632 } 633 634 if (execMapping & OpenACCExecMapping::VECTOR) { 635 printer << " " << LoopOp::getVectorKeyword(); 636 637 // Print optional vector operand if present 638 if (Value vectorLength = this->vectorLength()) 639 printer << "(" << vectorLength << ": " << vectorLength.getType() << ")"; 640 } 641 642 // tile()? 643 printOperandList(tileOperands(), LoopOp::getTileKeyword(), printer); 644 645 // private()? 646 printOperandList(privateOperands(), LoopOp::getPrivateKeyword(), printer); 647 648 // reduction()? 649 printOperandList(reductionOperands(), LoopOp::getReductionKeyword(), printer); 650 651 if (getNumResults() > 0) 652 printer << " -> (" << getResultTypes() << ")"; 653 654 printer << ' '; 655 printer.printRegion(region(), 656 /*printEntryBlockArgs=*/false, 657 /*printBlockTerminators=*/true); 658 659 printer.printOptionalAttrDictWithKeyword( 660 (*this)->getAttrs(), {LoopOp::getExecutionMappingAttrName(), 661 LoopOp::getOperandSegmentSizeAttr()}); 662 } 663 664 LogicalResult acc::LoopOp::verify() { 665 // auto, independent and seq attribute are mutually exclusive. 666 if ((auto_() && (independent() || seq())) || (independent() && seq())) { 667 return emitError("only one of " + acc::LoopOp::getAutoAttrName() + ", " + 668 acc::LoopOp::getIndependentAttrName() + ", " + 669 acc::LoopOp::getSeqAttrName() + 670 " can be present at the same time"); 671 } 672 673 // Gang, worker and vector are incompatible with seq. 674 if (seq() && exec_mapping() != OpenACCExecMapping::NONE) 675 return emitError("gang, worker or vector cannot appear with the seq attr"); 676 677 // Check non-empty body(). 678 if (region().empty()) 679 return emitError("expected non-empty body."); 680 681 return success(); 682 } 683 684 //===----------------------------------------------------------------------===// 685 // DataOp 686 //===----------------------------------------------------------------------===// 687 688 LogicalResult acc::DataOp::verify() { 689 // 2.6.5. Data Construct restriction 690 // At least one copy, copyin, copyout, create, no_create, present, deviceptr, 691 // attach, or default clause must appear on a data construct. 692 if (getOperands().empty() && !defaultAttr()) 693 return emitError("at least one operand or the default attribute " 694 "must appear on the data operation"); 695 return success(); 696 } 697 698 unsigned DataOp::getNumDataOperands() { 699 return copyOperands().size() + copyinOperands().size() + 700 copyinReadonlyOperands().size() + copyoutOperands().size() + 701 copyoutZeroOperands().size() + createOperands().size() + 702 createZeroOperands().size() + noCreateOperands().size() + 703 presentOperands().size() + deviceptrOperands().size() + 704 attachOperands().size(); 705 } 706 707 Value DataOp::getDataOperand(unsigned i) { 708 unsigned numOptional = ifCond() ? 1 : 0; 709 return getOperand(numOptional + i); 710 } 711 712 //===----------------------------------------------------------------------===// 713 // ExitDataOp 714 //===----------------------------------------------------------------------===// 715 716 LogicalResult acc::ExitDataOp::verify() { 717 // 2.6.6. Data Exit Directive restriction 718 // At least one copyout, delete, or detach clause must appear on an exit data 719 // directive. 720 if (copyoutOperands().empty() && deleteOperands().empty() && 721 detachOperands().empty()) 722 return emitError( 723 "at least one operand in copyout, delete or detach must appear on the " 724 "exit data operation"); 725 726 // The async attribute represent the async clause without value. Therefore the 727 // attribute and operand cannot appear at the same time. 728 if (asyncOperand() && async()) 729 return emitError("async attribute cannot appear with asyncOperand"); 730 731 // The wait attribute represent the wait clause without values. Therefore the 732 // attribute and operands cannot appear at the same time. 733 if (!waitOperands().empty() && wait()) 734 return emitError("wait attribute cannot appear with waitOperands"); 735 736 if (waitDevnum() && waitOperands().empty()) 737 return emitError("wait_devnum cannot appear without waitOperands"); 738 739 return success(); 740 } 741 742 unsigned ExitDataOp::getNumDataOperands() { 743 return copyoutOperands().size() + deleteOperands().size() + 744 detachOperands().size(); 745 } 746 747 Value ExitDataOp::getDataOperand(unsigned i) { 748 unsigned numOptional = ifCond() ? 1 : 0; 749 numOptional += asyncOperand() ? 1 : 0; 750 numOptional += waitDevnum() ? 1 : 0; 751 return getOperand(waitOperands().size() + numOptional + i); 752 } 753 754 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results, 755 MLIRContext *context) { 756 results.add<RemoveConstantIfCondition<ExitDataOp>>(context); 757 } 758 759 //===----------------------------------------------------------------------===// 760 // EnterDataOp 761 //===----------------------------------------------------------------------===// 762 763 LogicalResult acc::EnterDataOp::verify() { 764 // 2.6.6. Data Enter Directive restriction 765 // At least one copyin, create, or attach clause must appear on an enter data 766 // directive. 767 if (copyinOperands().empty() && createOperands().empty() && 768 createZeroOperands().empty() && attachOperands().empty()) 769 return emitError( 770 "at least one operand in copyin, create, " 771 "create_zero or attach must appear on the enter data operation"); 772 773 // The async attribute represent the async clause without value. Therefore the 774 // attribute and operand cannot appear at the same time. 775 if (asyncOperand() && async()) 776 return emitError("async attribute cannot appear with asyncOperand"); 777 778 // The wait attribute represent the wait clause without values. Therefore the 779 // attribute and operands cannot appear at the same time. 780 if (!waitOperands().empty() && wait()) 781 return emitError("wait attribute cannot appear with waitOperands"); 782 783 if (waitDevnum() && waitOperands().empty()) 784 return emitError("wait_devnum cannot appear without waitOperands"); 785 786 return success(); 787 } 788 789 unsigned EnterDataOp::getNumDataOperands() { 790 return copyinOperands().size() + createOperands().size() + 791 createZeroOperands().size() + attachOperands().size(); 792 } 793 794 Value EnterDataOp::getDataOperand(unsigned i) { 795 unsigned numOptional = ifCond() ? 1 : 0; 796 numOptional += asyncOperand() ? 1 : 0; 797 numOptional += waitDevnum() ? 1 : 0; 798 return getOperand(waitOperands().size() + numOptional + i); 799 } 800 801 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results, 802 MLIRContext *context) { 803 results.add<RemoveConstantIfCondition<EnterDataOp>>(context); 804 } 805 806 //===----------------------------------------------------------------------===// 807 // InitOp 808 //===----------------------------------------------------------------------===// 809 810 LogicalResult acc::InitOp::verify() { 811 Operation *currOp = *this; 812 while ((currOp = currOp->getParentOp())) 813 if (isComputeOperation(currOp)) 814 return emitOpError("cannot be nested in a compute operation"); 815 return success(); 816 } 817 818 //===----------------------------------------------------------------------===// 819 // ShutdownOp 820 //===----------------------------------------------------------------------===// 821 822 LogicalResult acc::ShutdownOp::verify() { 823 Operation *currOp = *this; 824 while ((currOp = currOp->getParentOp())) 825 if (isComputeOperation(currOp)) 826 return emitOpError("cannot be nested in a compute operation"); 827 return success(); 828 } 829 830 //===----------------------------------------------------------------------===// 831 // UpdateOp 832 //===----------------------------------------------------------------------===// 833 834 LogicalResult acc::UpdateOp::verify() { 835 // At least one of host or device should have a value. 836 if (hostOperands().empty() && deviceOperands().empty()) 837 return emitError( 838 "at least one value must be present in hostOperands or deviceOperands"); 839 840 // The async attribute represent the async clause without value. Therefore the 841 // attribute and operand cannot appear at the same time. 842 if (asyncOperand() && async()) 843 return emitError("async attribute cannot appear with asyncOperand"); 844 845 // The wait attribute represent the wait clause without values. Therefore the 846 // attribute and operands cannot appear at the same time. 847 if (!waitOperands().empty() && wait()) 848 return emitError("wait attribute cannot appear with waitOperands"); 849 850 if (waitDevnum() && waitOperands().empty()) 851 return emitError("wait_devnum cannot appear without waitOperands"); 852 853 return success(); 854 } 855 856 unsigned UpdateOp::getNumDataOperands() { 857 return hostOperands().size() + deviceOperands().size(); 858 } 859 860 Value UpdateOp::getDataOperand(unsigned i) { 861 unsigned numOptional = asyncOperand() ? 1 : 0; 862 numOptional += waitDevnum() ? 1 : 0; 863 numOptional += ifCond() ? 1 : 0; 864 return getOperand(waitOperands().size() + deviceTypeOperands().size() + 865 numOptional + i); 866 } 867 868 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results, 869 MLIRContext *context) { 870 results.add<RemoveConstantIfCondition<UpdateOp>>(context); 871 } 872 873 //===----------------------------------------------------------------------===// 874 // WaitOp 875 //===----------------------------------------------------------------------===// 876 877 LogicalResult acc::WaitOp::verify() { 878 // The async attribute represent the async clause without value. Therefore the 879 // attribute and operand cannot appear at the same time. 880 if (asyncOperand() && async()) 881 return emitError("async attribute cannot appear with asyncOperand"); 882 883 if (waitDevnum() && waitOperands().empty()) 884 return emitError("wait_devnum cannot appear without waitOperands"); 885 886 return success(); 887 } 888 889 #define GET_OP_CLASSES 890 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc" 891 892 #define GET_ATTRDEF_CLASSES 893 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc" 894