1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 // ============================================================================= 8 9 #include "mlir/Dialect/OpenACC/OpenACC.h" 10 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc" 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "mlir/IR/DialectImplementation.h" 14 #include "mlir/IR/Matchers.h" 15 #include "mlir/IR/OpImplementation.h" 16 #include "mlir/Transforms/DialectConversion.h" 17 #include "llvm/ADT/TypeSwitch.h" 18 19 using namespace mlir; 20 using namespace acc; 21 22 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc" 23 24 //===----------------------------------------------------------------------===// 25 // OpenACC operations 26 //===----------------------------------------------------------------------===// 27 28 void OpenACCDialect::initialize() { 29 addOperations< 30 #define GET_OP_LIST 31 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc" 32 >(); 33 addAttributes< 34 #define GET_ATTRDEF_LIST 35 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc" 36 >(); 37 } 38 39 template <typename StructureOp> 40 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, 41 unsigned nRegions = 1) { 42 43 SmallVector<Region *, 2> regions; 44 for (unsigned i = 0; i < nRegions; ++i) 45 regions.push_back(state.addRegion()); 46 47 for (Region *region : regions) { 48 if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{})) 49 return failure(); 50 } 51 52 return success(); 53 } 54 55 static ParseResult 56 parseOperandList(OpAsmParser &parser, StringRef keyword, 57 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &args, 58 SmallVectorImpl<Type> &argTypes, OperationState &result) { 59 if (failed(parser.parseOptionalKeyword(keyword))) 60 return success(); 61 62 if (failed(parser.parseLParen())) 63 return failure(); 64 65 // Exit early if the list is empty. 66 if (succeeded(parser.parseOptionalRParen())) 67 return success(); 68 69 do { 70 OpAsmParser::UnresolvedOperand arg; 71 Type type; 72 73 if (parser.parseOperand(arg, /*allowResultNumber=*/false) || 74 parser.parseColonType(type)) 75 return failure(); 76 77 args.push_back(arg); 78 argTypes.push_back(type); 79 } while (succeeded(parser.parseOptionalComma())); 80 81 if (failed(parser.parseRParen())) 82 return failure(); 83 84 return parser.resolveOperands(args, argTypes, parser.getCurrentLocation(), 85 result.operands); 86 } 87 88 static void printOperandList(Operation::operand_range operands, 89 StringRef listName, OpAsmPrinter &printer) { 90 91 if (!operands.empty()) { 92 printer << " " << listName << "("; 93 llvm::interleaveComma(operands, printer, [&](Value op) { 94 printer << op << ": " << op.getType(); 95 }); 96 printer << ")"; 97 } 98 } 99 100 static ParseResult parseOptionalOperand(OpAsmParser &parser, StringRef keyword, 101 OpAsmParser::UnresolvedOperand &operand, 102 Type type, bool &hasOptional, 103 OperationState &result) { 104 hasOptional = false; 105 if (succeeded(parser.parseOptionalKeyword(keyword))) { 106 hasOptional = true; 107 if (parser.parseLParen() || parser.parseOperand(operand) || 108 parser.resolveOperand(operand, type, result.operands) || 109 parser.parseRParen()) 110 return failure(); 111 } 112 return success(); 113 } 114 115 static ParseResult parseOperandAndType(OpAsmParser &parser, 116 OperationState &result) { 117 OpAsmParser::UnresolvedOperand operand; 118 Type type; 119 if (parser.parseOperand(operand) || parser.parseColonType(type) || 120 parser.resolveOperand(operand, type, result.operands)) 121 return failure(); 122 return success(); 123 } 124 125 /// Parse optional operand and its type wrapped in parenthesis prefixed with 126 /// a keyword. 127 /// Example: 128 /// keyword `(` %vectorLength: i64 `)` 129 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser, 130 StringRef keyword, 131 OperationState &result) { 132 OpAsmParser::UnresolvedOperand operand; 133 if (succeeded(parser.parseOptionalKeyword(keyword))) { 134 return failure(parser.parseLParen() || 135 parseOperandAndType(parser, result) || parser.parseRParen()); 136 } 137 return llvm::None; 138 } 139 140 /// Parse optional operand and its type wrapped in parenthesis. 141 /// Example: 142 /// `(` %vectorLength: i64 `)` 143 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser, 144 OperationState &result) { 145 if (succeeded(parser.parseOptionalLParen())) { 146 return failure(parseOperandAndType(parser, result) || parser.parseRParen()); 147 } 148 return llvm::None; 149 } 150 151 /// Parse optional operand with its type prefixed with prefixKeyword `=`. 152 /// Example: 153 /// num=%gangNum: i32 154 static OptionalParseResult parserOptionalOperandAndTypeWithPrefix( 155 OpAsmParser &parser, OperationState &result, StringRef prefixKeyword) { 156 if (succeeded(parser.parseOptionalKeyword(prefixKeyword))) { 157 parser.parseEqual(); 158 return parseOperandAndType(parser, result); 159 } 160 return llvm::None; 161 } 162 163 static bool isComputeOperation(Operation *op) { 164 return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op); 165 } 166 167 namespace { 168 /// Pattern to remove operation without region that have constant false `ifCond` 169 /// and remove the condition from the operation if the `ifCond` is a true 170 /// constant. 171 template <typename OpTy> 172 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> { 173 using OpRewritePattern<OpTy>::OpRewritePattern; 174 175 LogicalResult matchAndRewrite(OpTy op, 176 PatternRewriter &rewriter) const override { 177 // Early return if there is no condition. 178 Value ifCond = op.ifCond(); 179 if (!ifCond) 180 return success(); 181 182 IntegerAttr constAttr; 183 if (matchPattern(ifCond, m_Constant(&constAttr))) { 184 if (constAttr.getInt()) 185 rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); }); 186 else 187 rewriter.eraseOp(op); 188 } 189 190 return success(); 191 } 192 }; 193 } // namespace 194 195 //===----------------------------------------------------------------------===// 196 // ParallelOp 197 //===----------------------------------------------------------------------===// 198 199 /// Parse acc.parallel operation 200 /// operation := `acc.parallel` `async` `(` index `)`? 201 /// `wait` `(` index-list `)`? 202 /// `num_gangs` `(` value `)`? 203 /// `num_workers` `(` value `)`? 204 /// `vector_length` `(` value `)`? 205 /// `if` `(` value `)`? 206 /// `self` `(` value `)`? 207 /// `reduction` `(` value-list `)`? 208 /// `copy` `(` value-list `)`? 209 /// `copyin` `(` value-list `)`? 210 /// `copyin_readonly` `(` value-list `)`? 211 /// `copyout` `(` value-list `)`? 212 /// `copyout_zero` `(` value-list `)`? 213 /// `create` `(` value-list `)`? 214 /// `create_zero` `(` value-list `)`? 215 /// `no_create` `(` value-list `)`? 216 /// `present` `(` value-list `)`? 217 /// `deviceptr` `(` value-list `)`? 218 /// `attach` `(` value-list `)`? 219 /// `private` `(` value-list `)`? 220 /// `firstprivate` `(` value-list `)`? 221 /// region attr-dict? 222 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) { 223 Builder &builder = parser.getBuilder(); 224 SmallVector<OpAsmParser::UnresolvedOperand, 8> privateOperands, 225 firstprivateOperands, copyOperands, copyinOperands, 226 copyinReadonlyOperands, copyoutOperands, copyoutZeroOperands, 227 createOperands, createZeroOperands, noCreateOperands, presentOperands, 228 devicePtrOperands, attachOperands, waitOperands, reductionOperands; 229 SmallVector<Type, 8> waitOperandTypes, reductionOperandTypes, 230 copyOperandTypes, copyinOperandTypes, copyinReadonlyOperandTypes, 231 copyoutOperandTypes, copyoutZeroOperandTypes, createOperandTypes, 232 createZeroOperandTypes, noCreateOperandTypes, presentOperandTypes, 233 deviceptrOperandTypes, attachOperandTypes, privateOperandTypes, 234 firstprivateOperandTypes; 235 236 SmallVector<Type, 8> operandTypes; 237 OpAsmParser::UnresolvedOperand ifCond, selfCond; 238 bool hasIfCond = false, hasSelfCond = false; 239 OptionalParseResult async, numGangs, numWorkers, vectorLength; 240 Type i1Type = builder.getI1Type(); 241 242 // async()? 243 async = parseOptionalOperandAndType(parser, ParallelOp::getAsyncKeyword(), 244 result); 245 if (async.hasValue() && failed(*async)) 246 return failure(); 247 248 // wait()? 249 if (failed(parseOperandList(parser, ParallelOp::getWaitKeyword(), 250 waitOperands, waitOperandTypes, result))) 251 return failure(); 252 253 // num_gangs(value)? 254 numGangs = parseOptionalOperandAndType( 255 parser, ParallelOp::getNumGangsKeyword(), result); 256 if (numGangs.hasValue() && failed(*numGangs)) 257 return failure(); 258 259 // num_workers(value)? 260 numWorkers = parseOptionalOperandAndType( 261 parser, ParallelOp::getNumWorkersKeyword(), result); 262 if (numWorkers.hasValue() && failed(*numWorkers)) 263 return failure(); 264 265 // vector_length(value)? 266 vectorLength = parseOptionalOperandAndType( 267 parser, ParallelOp::getVectorLengthKeyword(), result); 268 if (vectorLength.hasValue() && failed(*vectorLength)) 269 return failure(); 270 271 // if()? 272 if (failed(parseOptionalOperand(parser, ParallelOp::getIfKeyword(), ifCond, 273 i1Type, hasIfCond, result))) 274 return failure(); 275 276 // self()? 277 if (failed(parseOptionalOperand(parser, ParallelOp::getSelfKeyword(), 278 selfCond, i1Type, hasSelfCond, result))) 279 return failure(); 280 281 // reduction()? 282 if (failed(parseOperandList(parser, ParallelOp::getReductionKeyword(), 283 reductionOperands, reductionOperandTypes, 284 result))) 285 return failure(); 286 287 // copy()? 288 if (failed(parseOperandList(parser, ParallelOp::getCopyKeyword(), 289 copyOperands, copyOperandTypes, result))) 290 return failure(); 291 292 // copyin()? 293 if (failed(parseOperandList(parser, ParallelOp::getCopyinKeyword(), 294 copyinOperands, copyinOperandTypes, result))) 295 return failure(); 296 297 // copyin_readonly()? 298 if (failed(parseOperandList(parser, ParallelOp::getCopyinReadonlyKeyword(), 299 copyinReadonlyOperands, 300 copyinReadonlyOperandTypes, result))) 301 return failure(); 302 303 // copyout()? 304 if (failed(parseOperandList(parser, ParallelOp::getCopyoutKeyword(), 305 copyoutOperands, copyoutOperandTypes, result))) 306 return failure(); 307 308 // copyout_zero()? 309 if (failed(parseOperandList(parser, ParallelOp::getCopyoutZeroKeyword(), 310 copyoutZeroOperands, copyoutZeroOperandTypes, 311 result))) 312 return failure(); 313 314 // create()? 315 if (failed(parseOperandList(parser, ParallelOp::getCreateKeyword(), 316 createOperands, createOperandTypes, result))) 317 return failure(); 318 319 // create_zero()? 320 if (failed(parseOperandList(parser, ParallelOp::getCreateZeroKeyword(), 321 createZeroOperands, createZeroOperandTypes, 322 result))) 323 return failure(); 324 325 // no_create()? 326 if (failed(parseOperandList(parser, ParallelOp::getNoCreateKeyword(), 327 noCreateOperands, noCreateOperandTypes, result))) 328 return failure(); 329 330 // present()? 331 if (failed(parseOperandList(parser, ParallelOp::getPresentKeyword(), 332 presentOperands, presentOperandTypes, result))) 333 return failure(); 334 335 // deviceptr()? 336 if (failed(parseOperandList(parser, ParallelOp::getDevicePtrKeyword(), 337 devicePtrOperands, deviceptrOperandTypes, 338 result))) 339 return failure(); 340 341 // attach()? 342 if (failed(parseOperandList(parser, ParallelOp::getAttachKeyword(), 343 attachOperands, attachOperandTypes, result))) 344 return failure(); 345 346 // private()? 347 if (failed(parseOperandList(parser, ParallelOp::getPrivateKeyword(), 348 privateOperands, privateOperandTypes, result))) 349 return failure(); 350 351 // firstprivate()? 352 if (failed(parseOperandList(parser, ParallelOp::getFirstPrivateKeyword(), 353 firstprivateOperands, firstprivateOperandTypes, 354 result))) 355 return failure(); 356 357 // Parallel op region 358 if (failed(parseRegions<ParallelOp>(parser, result))) 359 return failure(); 360 361 result.addAttribute( 362 ParallelOp::getOperandSegmentSizeAttr(), 363 builder.getI32VectorAttr( 364 {static_cast<int32_t>(async.hasValue() ? 1 : 0), 365 static_cast<int32_t>(waitOperands.size()), 366 static_cast<int32_t>(numGangs.hasValue() ? 1 : 0), 367 static_cast<int32_t>(numWorkers.hasValue() ? 1 : 0), 368 static_cast<int32_t>(vectorLength.hasValue() ? 1 : 0), 369 static_cast<int32_t>(hasIfCond ? 1 : 0), 370 static_cast<int32_t>(hasSelfCond ? 1 : 0), 371 static_cast<int32_t>(reductionOperands.size()), 372 static_cast<int32_t>(copyOperands.size()), 373 static_cast<int32_t>(copyinOperands.size()), 374 static_cast<int32_t>(copyinReadonlyOperands.size()), 375 static_cast<int32_t>(copyoutOperands.size()), 376 static_cast<int32_t>(copyoutZeroOperands.size()), 377 static_cast<int32_t>(createOperands.size()), 378 static_cast<int32_t>(createZeroOperands.size()), 379 static_cast<int32_t>(noCreateOperands.size()), 380 static_cast<int32_t>(presentOperands.size()), 381 static_cast<int32_t>(devicePtrOperands.size()), 382 static_cast<int32_t>(attachOperands.size()), 383 static_cast<int32_t>(privateOperands.size()), 384 static_cast<int32_t>(firstprivateOperands.size())})); 385 386 // Additional attributes 387 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 388 return failure(); 389 390 return success(); 391 } 392 393 void ParallelOp::print(OpAsmPrinter &printer) { 394 // async()? 395 if (Value async = this->async()) 396 printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ": " 397 << async.getType() << ")"; 398 399 // wait()? 400 printOperandList(waitOperands(), ParallelOp::getWaitKeyword(), printer); 401 402 // num_gangs()? 403 if (Value numGangs = this->numGangs()) 404 printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs 405 << ": " << numGangs.getType() << ")"; 406 407 // num_workers()? 408 if (Value numWorkers = this->numWorkers()) 409 printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers 410 << ": " << numWorkers.getType() << ")"; 411 412 // vector_length()? 413 if (Value vectorLength = this->vectorLength()) 414 printer << " " << ParallelOp::getVectorLengthKeyword() << "(" 415 << vectorLength << ": " << vectorLength.getType() << ")"; 416 417 // if()? 418 if (Value ifCond = this->ifCond()) 419 printer << " " << ParallelOp::getIfKeyword() << "(" << ifCond << ")"; 420 421 // self()? 422 if (Value selfCond = this->selfCond()) 423 printer << " " << ParallelOp::getSelfKeyword() << "(" << selfCond << ")"; 424 425 // reduction()? 426 printOperandList(reductionOperands(), ParallelOp::getReductionKeyword(), 427 printer); 428 429 // copy()? 430 printOperandList(copyOperands(), ParallelOp::getCopyKeyword(), printer); 431 432 // copyin()? 433 printOperandList(copyinOperands(), ParallelOp::getCopyinKeyword(), printer); 434 435 // copyin_readonly()? 436 printOperandList(copyinReadonlyOperands(), 437 ParallelOp::getCopyinReadonlyKeyword(), printer); 438 439 // copyout()? 440 printOperandList(copyoutOperands(), ParallelOp::getCopyoutKeyword(), printer); 441 442 // copyout_zero()? 443 printOperandList(copyoutZeroOperands(), ParallelOp::getCopyoutZeroKeyword(), 444 printer); 445 446 // create()? 447 printOperandList(createOperands(), ParallelOp::getCreateKeyword(), printer); 448 449 // create_zero()? 450 printOperandList(createZeroOperands(), ParallelOp::getCreateZeroKeyword(), 451 printer); 452 453 // no_create()? 454 printOperandList(noCreateOperands(), ParallelOp::getNoCreateKeyword(), 455 printer); 456 457 // present()? 458 printOperandList(presentOperands(), ParallelOp::getPresentKeyword(), printer); 459 460 // deviceptr()? 461 printOperandList(devicePtrOperands(), ParallelOp::getDevicePtrKeyword(), 462 printer); 463 464 // attach()? 465 printOperandList(attachOperands(), ParallelOp::getAttachKeyword(), printer); 466 467 // private()? 468 printOperandList(gangPrivateOperands(), ParallelOp::getPrivateKeyword(), 469 printer); 470 471 // firstprivate()? 472 printOperandList(gangFirstPrivateOperands(), 473 ParallelOp::getFirstPrivateKeyword(), printer); 474 475 printer << ' '; 476 printer.printRegion(region(), 477 /*printEntryBlockArgs=*/false, 478 /*printBlockTerminators=*/true); 479 printer.printOptionalAttrDictWithKeyword( 480 (*this)->getAttrs(), ParallelOp::getOperandSegmentSizeAttr()); 481 } 482 483 unsigned ParallelOp::getNumDataOperands() { 484 return reductionOperands().size() + copyOperands().size() + 485 copyinOperands().size() + copyinReadonlyOperands().size() + 486 copyoutOperands().size() + copyoutZeroOperands().size() + 487 createOperands().size() + createZeroOperands().size() + 488 noCreateOperands().size() + presentOperands().size() + 489 devicePtrOperands().size() + attachOperands().size() + 490 gangPrivateOperands().size() + gangFirstPrivateOperands().size(); 491 } 492 493 Value ParallelOp::getDataOperand(unsigned i) { 494 unsigned numOptional = async() ? 1 : 0; 495 numOptional += numGangs() ? 1 : 0; 496 numOptional += numWorkers() ? 1 : 0; 497 numOptional += vectorLength() ? 1 : 0; 498 numOptional += ifCond() ? 1 : 0; 499 numOptional += selfCond() ? 1 : 0; 500 return getOperand(waitOperands().size() + numOptional + i); 501 } 502 503 //===----------------------------------------------------------------------===// 504 // LoopOp 505 //===----------------------------------------------------------------------===// 506 507 /// Parse acc.loop operation 508 /// operation := `acc.loop` 509 /// (`gang` ( `(` (`num=` value)? (`,` `static=` value `)`)? )? )? 510 /// (`vector` ( `(` value `)` )? )? (`worker` (`(` value `)`)? )? 511 /// (`vector_length` `(` value `)`)? 512 /// (`tile` `(` value-list `)`)? 513 /// (`private` `(` value-list `)`)? 514 /// (`reduction` `(` value-list `)`)? 515 /// region attr-dict? 516 ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) { 517 Builder &builder = parser.getBuilder(); 518 unsigned executionMapping = OpenACCExecMapping::NONE; 519 SmallVector<Type, 8> operandTypes; 520 SmallVector<OpAsmParser::UnresolvedOperand, 8> privateOperands, 521 reductionOperands; 522 SmallVector<OpAsmParser::UnresolvedOperand, 8> tileOperands; 523 OptionalParseResult gangNum, gangStatic, worker, vector; 524 525 // gang? 526 if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangKeyword()))) 527 executionMapping |= OpenACCExecMapping::GANG; 528 529 // optional gang operand 530 if (succeeded(parser.parseOptionalLParen())) { 531 gangNum = parserOptionalOperandAndTypeWithPrefix( 532 parser, result, LoopOp::getGangNumKeyword()); 533 if (gangNum.hasValue() && failed(*gangNum)) 534 return failure(); 535 parser.parseOptionalComma(); 536 gangStatic = parserOptionalOperandAndTypeWithPrefix( 537 parser, result, LoopOp::getGangStaticKeyword()); 538 if (gangStatic.hasValue() && failed(*gangStatic)) 539 return failure(); 540 parser.parseOptionalComma(); 541 if (failed(parser.parseRParen())) 542 return failure(); 543 } 544 545 // worker? 546 if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerKeyword()))) 547 executionMapping |= OpenACCExecMapping::WORKER; 548 549 // optional worker operand 550 worker = parseOptionalOperandAndType(parser, result); 551 if (worker.hasValue() && failed(*worker)) 552 return failure(); 553 554 // vector? 555 if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorKeyword()))) 556 executionMapping |= OpenACCExecMapping::VECTOR; 557 558 // optional vector operand 559 vector = parseOptionalOperandAndType(parser, result); 560 if (vector.hasValue() && failed(*vector)) 561 return failure(); 562 563 // tile()? 564 if (failed(parseOperandList(parser, LoopOp::getTileKeyword(), tileOperands, 565 operandTypes, result))) 566 return failure(); 567 568 // private()? 569 if (failed(parseOperandList(parser, LoopOp::getPrivateKeyword(), 570 privateOperands, operandTypes, result))) 571 return failure(); 572 573 // reduction()? 574 if (failed(parseOperandList(parser, LoopOp::getReductionKeyword(), 575 reductionOperands, operandTypes, result))) 576 return failure(); 577 578 if (executionMapping != acc::OpenACCExecMapping::NONE) 579 result.addAttribute(LoopOp::getExecutionMappingAttrName(), 580 builder.getI64IntegerAttr(executionMapping)); 581 582 // Parse optional results in case there is a reduce. 583 if (parser.parseOptionalArrowTypeList(result.types)) 584 return failure(); 585 586 if (failed(parseRegions<LoopOp>(parser, result))) 587 return failure(); 588 589 result.addAttribute(LoopOp::getOperandSegmentSizeAttr(), 590 builder.getI32VectorAttr( 591 {static_cast<int32_t>(gangNum.hasValue() ? 1 : 0), 592 static_cast<int32_t>(gangStatic.hasValue() ? 1 : 0), 593 static_cast<int32_t>(worker.hasValue() ? 1 : 0), 594 static_cast<int32_t>(vector.hasValue() ? 1 : 0), 595 static_cast<int32_t>(tileOperands.size()), 596 static_cast<int32_t>(privateOperands.size()), 597 static_cast<int32_t>(reductionOperands.size())})); 598 599 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 600 return failure(); 601 602 return success(); 603 } 604 605 void LoopOp::print(OpAsmPrinter &printer) { 606 unsigned execMapping = exec_mapping(); 607 if (execMapping & OpenACCExecMapping::GANG) { 608 printer << " " << LoopOp::getGangKeyword(); 609 Value gangNum = this->gangNum(); 610 Value gangStatic = this->gangStatic(); 611 612 // Print optional gang operands 613 if (gangNum || gangStatic) { 614 printer << "("; 615 if (gangNum) { 616 printer << LoopOp::getGangNumKeyword() << "=" << gangNum << ": " 617 << gangNum.getType(); 618 if (gangStatic) 619 printer << ", "; 620 } 621 if (gangStatic) 622 printer << LoopOp::getGangStaticKeyword() << "=" << gangStatic << ": " 623 << gangStatic.getType(); 624 printer << ")"; 625 } 626 } 627 628 if (execMapping & OpenACCExecMapping::WORKER) { 629 printer << " " << LoopOp::getWorkerKeyword(); 630 631 // Print optional worker operand if present 632 if (Value workerNum = this->workerNum()) 633 printer << "(" << workerNum << ": " << workerNum.getType() << ")"; 634 } 635 636 if (execMapping & OpenACCExecMapping::VECTOR) { 637 printer << " " << LoopOp::getVectorKeyword(); 638 639 // Print optional vector operand if present 640 if (Value vectorLength = this->vectorLength()) 641 printer << "(" << vectorLength << ": " << vectorLength.getType() << ")"; 642 } 643 644 // tile()? 645 printOperandList(tileOperands(), LoopOp::getTileKeyword(), printer); 646 647 // private()? 648 printOperandList(privateOperands(), LoopOp::getPrivateKeyword(), printer); 649 650 // reduction()? 651 printOperandList(reductionOperands(), LoopOp::getReductionKeyword(), printer); 652 653 if (getNumResults() > 0) 654 printer << " -> (" << getResultTypes() << ")"; 655 656 printer << ' '; 657 printer.printRegion(region(), 658 /*printEntryBlockArgs=*/false, 659 /*printBlockTerminators=*/true); 660 661 printer.printOptionalAttrDictWithKeyword( 662 (*this)->getAttrs(), {LoopOp::getExecutionMappingAttrName(), 663 LoopOp::getOperandSegmentSizeAttr()}); 664 } 665 666 LogicalResult acc::LoopOp::verify() { 667 // auto, independent and seq attribute are mutually exclusive. 668 if ((auto_() && (independent() || seq())) || (independent() && seq())) { 669 return emitError("only one of " + acc::LoopOp::getAutoAttrName() + ", " + 670 acc::LoopOp::getIndependentAttrName() + ", " + 671 acc::LoopOp::getSeqAttrName() + 672 " can be present at the same time"); 673 } 674 675 // Gang, worker and vector are incompatible with seq. 676 if (seq() && exec_mapping() != OpenACCExecMapping::NONE) 677 return emitError("gang, worker or vector cannot appear with the seq attr"); 678 679 // Check non-empty body(). 680 if (region().empty()) 681 return emitError("expected non-empty body."); 682 683 return success(); 684 } 685 686 //===----------------------------------------------------------------------===// 687 // DataOp 688 //===----------------------------------------------------------------------===// 689 690 LogicalResult acc::DataOp::verify() { 691 // 2.6.5. Data Construct restriction 692 // At least one copy, copyin, copyout, create, no_create, present, deviceptr, 693 // attach, or default clause must appear on a data construct. 694 if (getOperands().empty() && !defaultAttr()) 695 return emitError("at least one operand or the default attribute " 696 "must appear on the data operation"); 697 return success(); 698 } 699 700 unsigned DataOp::getNumDataOperands() { 701 return copyOperands().size() + copyinOperands().size() + 702 copyinReadonlyOperands().size() + copyoutOperands().size() + 703 copyoutZeroOperands().size() + createOperands().size() + 704 createZeroOperands().size() + noCreateOperands().size() + 705 presentOperands().size() + deviceptrOperands().size() + 706 attachOperands().size(); 707 } 708 709 Value DataOp::getDataOperand(unsigned i) { 710 unsigned numOptional = ifCond() ? 1 : 0; 711 return getOperand(numOptional + i); 712 } 713 714 //===----------------------------------------------------------------------===// 715 // ExitDataOp 716 //===----------------------------------------------------------------------===// 717 718 LogicalResult acc::ExitDataOp::verify() { 719 // 2.6.6. Data Exit Directive restriction 720 // At least one copyout, delete, or detach clause must appear on an exit data 721 // directive. 722 if (copyoutOperands().empty() && deleteOperands().empty() && 723 detachOperands().empty()) 724 return emitError( 725 "at least one operand in copyout, delete or detach must appear on the " 726 "exit data operation"); 727 728 // The async attribute represent the async clause without value. Therefore the 729 // attribute and operand cannot appear at the same time. 730 if (asyncOperand() && async()) 731 return emitError("async attribute cannot appear with asyncOperand"); 732 733 // The wait attribute represent the wait clause without values. Therefore the 734 // attribute and operands cannot appear at the same time. 735 if (!waitOperands().empty() && wait()) 736 return emitError("wait attribute cannot appear with waitOperands"); 737 738 if (waitDevnum() && waitOperands().empty()) 739 return emitError("wait_devnum cannot appear without waitOperands"); 740 741 return success(); 742 } 743 744 unsigned ExitDataOp::getNumDataOperands() { 745 return copyoutOperands().size() + deleteOperands().size() + 746 detachOperands().size(); 747 } 748 749 Value ExitDataOp::getDataOperand(unsigned i) { 750 unsigned numOptional = ifCond() ? 1 : 0; 751 numOptional += asyncOperand() ? 1 : 0; 752 numOptional += waitDevnum() ? 1 : 0; 753 return getOperand(waitOperands().size() + numOptional + i); 754 } 755 756 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results, 757 MLIRContext *context) { 758 results.add<RemoveConstantIfCondition<ExitDataOp>>(context); 759 } 760 761 //===----------------------------------------------------------------------===// 762 // EnterDataOp 763 //===----------------------------------------------------------------------===// 764 765 LogicalResult acc::EnterDataOp::verify() { 766 // 2.6.6. Data Enter Directive restriction 767 // At least one copyin, create, or attach clause must appear on an enter data 768 // directive. 769 if (copyinOperands().empty() && createOperands().empty() && 770 createZeroOperands().empty() && attachOperands().empty()) 771 return emitError( 772 "at least one operand in copyin, create, " 773 "create_zero or attach must appear on the enter data operation"); 774 775 // The async attribute represent the async clause without value. Therefore the 776 // attribute and operand cannot appear at the same time. 777 if (asyncOperand() && async()) 778 return emitError("async attribute cannot appear with asyncOperand"); 779 780 // The wait attribute represent the wait clause without values. Therefore the 781 // attribute and operands cannot appear at the same time. 782 if (!waitOperands().empty() && wait()) 783 return emitError("wait attribute cannot appear with waitOperands"); 784 785 if (waitDevnum() && waitOperands().empty()) 786 return emitError("wait_devnum cannot appear without waitOperands"); 787 788 return success(); 789 } 790 791 unsigned EnterDataOp::getNumDataOperands() { 792 return copyinOperands().size() + createOperands().size() + 793 createZeroOperands().size() + attachOperands().size(); 794 } 795 796 Value EnterDataOp::getDataOperand(unsigned i) { 797 unsigned numOptional = ifCond() ? 1 : 0; 798 numOptional += asyncOperand() ? 1 : 0; 799 numOptional += waitDevnum() ? 1 : 0; 800 return getOperand(waitOperands().size() + numOptional + i); 801 } 802 803 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results, 804 MLIRContext *context) { 805 results.add<RemoveConstantIfCondition<EnterDataOp>>(context); 806 } 807 808 //===----------------------------------------------------------------------===// 809 // InitOp 810 //===----------------------------------------------------------------------===// 811 812 LogicalResult acc::InitOp::verify() { 813 Operation *currOp = *this; 814 while ((currOp = currOp->getParentOp())) 815 if (isComputeOperation(currOp)) 816 return emitOpError("cannot be nested in a compute operation"); 817 return success(); 818 } 819 820 //===----------------------------------------------------------------------===// 821 // ShutdownOp 822 //===----------------------------------------------------------------------===// 823 824 LogicalResult acc::ShutdownOp::verify() { 825 Operation *currOp = *this; 826 while ((currOp = currOp->getParentOp())) 827 if (isComputeOperation(currOp)) 828 return emitOpError("cannot be nested in a compute operation"); 829 return success(); 830 } 831 832 //===----------------------------------------------------------------------===// 833 // UpdateOp 834 //===----------------------------------------------------------------------===// 835 836 LogicalResult acc::UpdateOp::verify() { 837 // At least one of host or device should have a value. 838 if (hostOperands().empty() && deviceOperands().empty()) 839 return emitError( 840 "at least one value must be present in hostOperands or deviceOperands"); 841 842 // The async attribute represent the async clause without value. Therefore the 843 // attribute and operand cannot appear at the same time. 844 if (asyncOperand() && async()) 845 return emitError("async attribute cannot appear with asyncOperand"); 846 847 // The wait attribute represent the wait clause without values. Therefore the 848 // attribute and operands cannot appear at the same time. 849 if (!waitOperands().empty() && wait()) 850 return emitError("wait attribute cannot appear with waitOperands"); 851 852 if (waitDevnum() && waitOperands().empty()) 853 return emitError("wait_devnum cannot appear without waitOperands"); 854 855 return success(); 856 } 857 858 unsigned UpdateOp::getNumDataOperands() { 859 return hostOperands().size() + deviceOperands().size(); 860 } 861 862 Value UpdateOp::getDataOperand(unsigned i) { 863 unsigned numOptional = asyncOperand() ? 1 : 0; 864 numOptional += waitDevnum() ? 1 : 0; 865 numOptional += ifCond() ? 1 : 0; 866 return getOperand(waitOperands().size() + deviceTypeOperands().size() + 867 numOptional + i); 868 } 869 870 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results, 871 MLIRContext *context) { 872 results.add<RemoveConstantIfCondition<UpdateOp>>(context); 873 } 874 875 //===----------------------------------------------------------------------===// 876 // WaitOp 877 //===----------------------------------------------------------------------===// 878 879 LogicalResult acc::WaitOp::verify() { 880 // The async attribute represent the async clause without value. Therefore the 881 // attribute and operand cannot appear at the same time. 882 if (asyncOperand() && async()) 883 return emitError("async attribute cannot appear with asyncOperand"); 884 885 if (waitDevnum() && waitOperands().empty()) 886 return emitError("wait_devnum cannot appear without waitOperands"); 887 888 return success(); 889 } 890 891 #define GET_OP_CLASSES 892 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc" 893 894 #define GET_ATTRDEF_CLASSES 895 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc" 896