1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 // ============================================================================= 8 9 #include "mlir/Dialect/OpenACC/OpenACC.h" 10 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc" 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "mlir/IR/DialectImplementation.h" 14 #include "mlir/IR/Matchers.h" 15 #include "mlir/IR/OpImplementation.h" 16 #include "mlir/Transforms/DialectConversion.h" 17 #include "llvm/ADT/TypeSwitch.h" 18 19 using namespace mlir; 20 using namespace acc; 21 22 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc" 23 24 //===----------------------------------------------------------------------===// 25 // OpenACC operations 26 //===----------------------------------------------------------------------===// 27 28 void OpenACCDialect::initialize() { 29 addOperations< 30 #define GET_OP_LIST 31 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc" 32 >(); 33 addAttributes< 34 #define GET_ATTRDEF_LIST 35 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc" 36 >(); 37 } 38 39 template <typename StructureOp> 40 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, 41 unsigned nRegions = 1) { 42 43 SmallVector<Region *, 2> regions; 44 for (unsigned i = 0; i < nRegions; ++i) 45 regions.push_back(state.addRegion()); 46 47 for (Region *region : regions) { 48 if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{})) 49 return failure(); 50 } 51 52 return success(); 53 } 54 55 static ParseResult 56 parseOperandList(OpAsmParser &parser, StringRef keyword, 57 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &args, 58 SmallVectorImpl<Type> &argTypes, OperationState &result) { 59 if (failed(parser.parseOptionalKeyword(keyword))) 60 return success(); 61 62 if (failed(parser.parseLParen())) 63 return failure(); 64 65 // Exit early if the list is empty. 66 if (succeeded(parser.parseOptionalRParen())) 67 return success(); 68 69 if (failed(parser.parseCommaSeparatedList([&]() { 70 OpAsmParser::UnresolvedOperand arg; 71 Type type; 72 73 if (parser.parseOperand(arg, /*allowResultNumber=*/false) || 74 parser.parseColonType(type)) 75 return failure(); 76 77 args.push_back(arg); 78 argTypes.push_back(type); 79 return success(); 80 })) || 81 failed(parser.parseRParen())) 82 return failure(); 83 84 return parser.resolveOperands(args, argTypes, parser.getCurrentLocation(), 85 result.operands); 86 } 87 88 static void printOperandList(Operation::operand_range operands, 89 StringRef listName, OpAsmPrinter &printer) { 90 91 if (!operands.empty()) { 92 printer << " " << listName << "("; 93 llvm::interleaveComma(operands, printer, [&](Value op) { 94 printer << op << ": " << op.getType(); 95 }); 96 printer << ")"; 97 } 98 } 99 100 static ParseResult parseOptionalOperand(OpAsmParser &parser, StringRef keyword, 101 OpAsmParser::UnresolvedOperand &operand, 102 Type type, bool &hasOptional, 103 OperationState &result) { 104 hasOptional = false; 105 if (succeeded(parser.parseOptionalKeyword(keyword))) { 106 hasOptional = true; 107 if (parser.parseLParen() || parser.parseOperand(operand) || 108 parser.resolveOperand(operand, type, result.operands) || 109 parser.parseRParen()) 110 return failure(); 111 } 112 return success(); 113 } 114 115 static ParseResult parseOperandAndType(OpAsmParser &parser, 116 OperationState &result) { 117 OpAsmParser::UnresolvedOperand operand; 118 Type type; 119 if (parser.parseOperand(operand) || parser.parseColonType(type) || 120 parser.resolveOperand(operand, type, result.operands)) 121 return failure(); 122 return success(); 123 } 124 125 /// Parse optional operand and its type wrapped in parenthesis prefixed with 126 /// a keyword. 127 /// Example: 128 /// keyword `(` %vectorLength: i64 `)` 129 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser, 130 StringRef keyword, 131 OperationState &result) { 132 OpAsmParser::UnresolvedOperand operand; 133 if (succeeded(parser.parseOptionalKeyword(keyword))) { 134 return failure(parser.parseLParen() || 135 parseOperandAndType(parser, result) || parser.parseRParen()); 136 } 137 return llvm::None; 138 } 139 140 /// Parse optional operand and its type wrapped in parenthesis. 141 /// Example: 142 /// `(` %vectorLength: i64 `)` 143 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser, 144 OperationState &result) { 145 if (succeeded(parser.parseOptionalLParen())) { 146 return failure(parseOperandAndType(parser, result) || parser.parseRParen()); 147 } 148 return llvm::None; 149 } 150 151 /// Parse optional operand with its type prefixed with prefixKeyword `=`. 152 /// Example: 153 /// num=%gangNum: i32 154 static OptionalParseResult parserOptionalOperandAndTypeWithPrefix( 155 OpAsmParser &parser, OperationState &result, StringRef prefixKeyword) { 156 if (succeeded(parser.parseOptionalKeyword(prefixKeyword))) { 157 if (parser.parseEqual() || parseOperandAndType(parser, result)) 158 return failure(); 159 return success(); 160 } 161 return llvm::None; 162 } 163 164 static bool isComputeOperation(Operation *op) { 165 return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op); 166 } 167 168 namespace { 169 /// Pattern to remove operation without region that have constant false `ifCond` 170 /// and remove the condition from the operation if the `ifCond` is a true 171 /// constant. 172 template <typename OpTy> 173 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> { 174 using OpRewritePattern<OpTy>::OpRewritePattern; 175 176 LogicalResult matchAndRewrite(OpTy op, 177 PatternRewriter &rewriter) const override { 178 // Early return if there is no condition. 179 Value ifCond = op.ifCond(); 180 if (!ifCond) 181 return success(); 182 183 IntegerAttr constAttr; 184 if (matchPattern(ifCond, m_Constant(&constAttr))) { 185 if (constAttr.getInt()) 186 rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); }); 187 else 188 rewriter.eraseOp(op); 189 } 190 191 return success(); 192 } 193 }; 194 } // namespace 195 196 //===----------------------------------------------------------------------===// 197 // ParallelOp 198 //===----------------------------------------------------------------------===// 199 200 /// Parse acc.parallel operation 201 /// operation := `acc.parallel` `async` `(` index `)`? 202 /// `wait` `(` index-list `)`? 203 /// `num_gangs` `(` value `)`? 204 /// `num_workers` `(` value `)`? 205 /// `vector_length` `(` value `)`? 206 /// `if` `(` value `)`? 207 /// `self` `(` value `)`? 208 /// `reduction` `(` value-list `)`? 209 /// `copy` `(` value-list `)`? 210 /// `copyin` `(` value-list `)`? 211 /// `copyin_readonly` `(` value-list `)`? 212 /// `copyout` `(` value-list `)`? 213 /// `copyout_zero` `(` value-list `)`? 214 /// `create` `(` value-list `)`? 215 /// `create_zero` `(` value-list `)`? 216 /// `no_create` `(` value-list `)`? 217 /// `present` `(` value-list `)`? 218 /// `deviceptr` `(` value-list `)`? 219 /// `attach` `(` value-list `)`? 220 /// `private` `(` value-list `)`? 221 /// `firstprivate` `(` value-list `)`? 222 /// region attr-dict? 223 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) { 224 Builder &builder = parser.getBuilder(); 225 SmallVector<OpAsmParser::UnresolvedOperand, 8> privateOperands, 226 firstprivateOperands, copyOperands, copyinOperands, 227 copyinReadonlyOperands, copyoutOperands, copyoutZeroOperands, 228 createOperands, createZeroOperands, noCreateOperands, presentOperands, 229 devicePtrOperands, attachOperands, waitOperands, reductionOperands; 230 SmallVector<Type, 8> waitOperandTypes, reductionOperandTypes, 231 copyOperandTypes, copyinOperandTypes, copyinReadonlyOperandTypes, 232 copyoutOperandTypes, copyoutZeroOperandTypes, createOperandTypes, 233 createZeroOperandTypes, noCreateOperandTypes, presentOperandTypes, 234 deviceptrOperandTypes, attachOperandTypes, privateOperandTypes, 235 firstprivateOperandTypes; 236 237 SmallVector<Type, 8> operandTypes; 238 OpAsmParser::UnresolvedOperand ifCond, selfCond; 239 bool hasIfCond = false, hasSelfCond = false; 240 OptionalParseResult async, numGangs, numWorkers, vectorLength; 241 Type i1Type = builder.getI1Type(); 242 243 // async()? 244 async = parseOptionalOperandAndType(parser, ParallelOp::getAsyncKeyword(), 245 result); 246 if (async.hasValue() && failed(*async)) 247 return failure(); 248 249 // wait()? 250 if (failed(parseOperandList(parser, ParallelOp::getWaitKeyword(), 251 waitOperands, waitOperandTypes, result))) 252 return failure(); 253 254 // num_gangs(value)? 255 numGangs = parseOptionalOperandAndType( 256 parser, ParallelOp::getNumGangsKeyword(), result); 257 if (numGangs.hasValue() && failed(*numGangs)) 258 return failure(); 259 260 // num_workers(value)? 261 numWorkers = parseOptionalOperandAndType( 262 parser, ParallelOp::getNumWorkersKeyword(), result); 263 if (numWorkers.hasValue() && failed(*numWorkers)) 264 return failure(); 265 266 // vector_length(value)? 267 vectorLength = parseOptionalOperandAndType( 268 parser, ParallelOp::getVectorLengthKeyword(), result); 269 if (vectorLength.hasValue() && failed(*vectorLength)) 270 return failure(); 271 272 // if()? 273 if (failed(parseOptionalOperand(parser, ParallelOp::getIfKeyword(), ifCond, 274 i1Type, hasIfCond, result))) 275 return failure(); 276 277 // self()? 278 if (failed(parseOptionalOperand(parser, ParallelOp::getSelfKeyword(), 279 selfCond, i1Type, hasSelfCond, result))) 280 return failure(); 281 282 // reduction()? 283 if (failed(parseOperandList(parser, ParallelOp::getReductionKeyword(), 284 reductionOperands, reductionOperandTypes, 285 result))) 286 return failure(); 287 288 // copy()? 289 if (failed(parseOperandList(parser, ParallelOp::getCopyKeyword(), 290 copyOperands, copyOperandTypes, result))) 291 return failure(); 292 293 // copyin()? 294 if (failed(parseOperandList(parser, ParallelOp::getCopyinKeyword(), 295 copyinOperands, copyinOperandTypes, result))) 296 return failure(); 297 298 // copyin_readonly()? 299 if (failed(parseOperandList(parser, ParallelOp::getCopyinReadonlyKeyword(), 300 copyinReadonlyOperands, 301 copyinReadonlyOperandTypes, result))) 302 return failure(); 303 304 // copyout()? 305 if (failed(parseOperandList(parser, ParallelOp::getCopyoutKeyword(), 306 copyoutOperands, copyoutOperandTypes, result))) 307 return failure(); 308 309 // copyout_zero()? 310 if (failed(parseOperandList(parser, ParallelOp::getCopyoutZeroKeyword(), 311 copyoutZeroOperands, copyoutZeroOperandTypes, 312 result))) 313 return failure(); 314 315 // create()? 316 if (failed(parseOperandList(parser, ParallelOp::getCreateKeyword(), 317 createOperands, createOperandTypes, result))) 318 return failure(); 319 320 // create_zero()? 321 if (failed(parseOperandList(parser, ParallelOp::getCreateZeroKeyword(), 322 createZeroOperands, createZeroOperandTypes, 323 result))) 324 return failure(); 325 326 // no_create()? 327 if (failed(parseOperandList(parser, ParallelOp::getNoCreateKeyword(), 328 noCreateOperands, noCreateOperandTypes, result))) 329 return failure(); 330 331 // present()? 332 if (failed(parseOperandList(parser, ParallelOp::getPresentKeyword(), 333 presentOperands, presentOperandTypes, result))) 334 return failure(); 335 336 // deviceptr()? 337 if (failed(parseOperandList(parser, ParallelOp::getDevicePtrKeyword(), 338 devicePtrOperands, deviceptrOperandTypes, 339 result))) 340 return failure(); 341 342 // attach()? 343 if (failed(parseOperandList(parser, ParallelOp::getAttachKeyword(), 344 attachOperands, attachOperandTypes, result))) 345 return failure(); 346 347 // private()? 348 if (failed(parseOperandList(parser, ParallelOp::getPrivateKeyword(), 349 privateOperands, privateOperandTypes, result))) 350 return failure(); 351 352 // firstprivate()? 353 if (failed(parseOperandList(parser, ParallelOp::getFirstPrivateKeyword(), 354 firstprivateOperands, firstprivateOperandTypes, 355 result))) 356 return failure(); 357 358 // Parallel op region 359 if (failed(parseRegions<ParallelOp>(parser, result))) 360 return failure(); 361 362 result.addAttribute( 363 ParallelOp::getOperandSegmentSizeAttr(), 364 builder.getI32VectorAttr( 365 {static_cast<int32_t>(async.hasValue() ? 1 : 0), 366 static_cast<int32_t>(waitOperands.size()), 367 static_cast<int32_t>(numGangs.hasValue() ? 1 : 0), 368 static_cast<int32_t>(numWorkers.hasValue() ? 1 : 0), 369 static_cast<int32_t>(vectorLength.hasValue() ? 1 : 0), 370 static_cast<int32_t>(hasIfCond ? 1 : 0), 371 static_cast<int32_t>(hasSelfCond ? 1 : 0), 372 static_cast<int32_t>(reductionOperands.size()), 373 static_cast<int32_t>(copyOperands.size()), 374 static_cast<int32_t>(copyinOperands.size()), 375 static_cast<int32_t>(copyinReadonlyOperands.size()), 376 static_cast<int32_t>(copyoutOperands.size()), 377 static_cast<int32_t>(copyoutZeroOperands.size()), 378 static_cast<int32_t>(createOperands.size()), 379 static_cast<int32_t>(createZeroOperands.size()), 380 static_cast<int32_t>(noCreateOperands.size()), 381 static_cast<int32_t>(presentOperands.size()), 382 static_cast<int32_t>(devicePtrOperands.size()), 383 static_cast<int32_t>(attachOperands.size()), 384 static_cast<int32_t>(privateOperands.size()), 385 static_cast<int32_t>(firstprivateOperands.size())})); 386 387 // Additional attributes 388 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 389 return failure(); 390 391 return success(); 392 } 393 394 void ParallelOp::print(OpAsmPrinter &printer) { 395 // async()? 396 if (Value async = this->async()) 397 printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ": " 398 << async.getType() << ")"; 399 400 // wait()? 401 printOperandList(waitOperands(), ParallelOp::getWaitKeyword(), printer); 402 403 // num_gangs()? 404 if (Value numGangs = this->numGangs()) 405 printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs 406 << ": " << numGangs.getType() << ")"; 407 408 // num_workers()? 409 if (Value numWorkers = this->numWorkers()) 410 printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers 411 << ": " << numWorkers.getType() << ")"; 412 413 // vector_length()? 414 if (Value vectorLength = this->vectorLength()) 415 printer << " " << ParallelOp::getVectorLengthKeyword() << "(" 416 << vectorLength << ": " << vectorLength.getType() << ")"; 417 418 // if()? 419 if (Value ifCond = this->ifCond()) 420 printer << " " << ParallelOp::getIfKeyword() << "(" << ifCond << ")"; 421 422 // self()? 423 if (Value selfCond = this->selfCond()) 424 printer << " " << ParallelOp::getSelfKeyword() << "(" << selfCond << ")"; 425 426 // reduction()? 427 printOperandList(reductionOperands(), ParallelOp::getReductionKeyword(), 428 printer); 429 430 // copy()? 431 printOperandList(copyOperands(), ParallelOp::getCopyKeyword(), printer); 432 433 // copyin()? 434 printOperandList(copyinOperands(), ParallelOp::getCopyinKeyword(), printer); 435 436 // copyin_readonly()? 437 printOperandList(copyinReadonlyOperands(), 438 ParallelOp::getCopyinReadonlyKeyword(), printer); 439 440 // copyout()? 441 printOperandList(copyoutOperands(), ParallelOp::getCopyoutKeyword(), printer); 442 443 // copyout_zero()? 444 printOperandList(copyoutZeroOperands(), ParallelOp::getCopyoutZeroKeyword(), 445 printer); 446 447 // create()? 448 printOperandList(createOperands(), ParallelOp::getCreateKeyword(), printer); 449 450 // create_zero()? 451 printOperandList(createZeroOperands(), ParallelOp::getCreateZeroKeyword(), 452 printer); 453 454 // no_create()? 455 printOperandList(noCreateOperands(), ParallelOp::getNoCreateKeyword(), 456 printer); 457 458 // present()? 459 printOperandList(presentOperands(), ParallelOp::getPresentKeyword(), printer); 460 461 // deviceptr()? 462 printOperandList(devicePtrOperands(), ParallelOp::getDevicePtrKeyword(), 463 printer); 464 465 // attach()? 466 printOperandList(attachOperands(), ParallelOp::getAttachKeyword(), printer); 467 468 // private()? 469 printOperandList(gangPrivateOperands(), ParallelOp::getPrivateKeyword(), 470 printer); 471 472 // firstprivate()? 473 printOperandList(gangFirstPrivateOperands(), 474 ParallelOp::getFirstPrivateKeyword(), printer); 475 476 printer << ' '; 477 printer.printRegion(region(), 478 /*printEntryBlockArgs=*/false, 479 /*printBlockTerminators=*/true); 480 printer.printOptionalAttrDictWithKeyword( 481 (*this)->getAttrs(), ParallelOp::getOperandSegmentSizeAttr()); 482 } 483 484 unsigned ParallelOp::getNumDataOperands() { 485 return reductionOperands().size() + copyOperands().size() + 486 copyinOperands().size() + copyinReadonlyOperands().size() + 487 copyoutOperands().size() + copyoutZeroOperands().size() + 488 createOperands().size() + createZeroOperands().size() + 489 noCreateOperands().size() + presentOperands().size() + 490 devicePtrOperands().size() + attachOperands().size() + 491 gangPrivateOperands().size() + gangFirstPrivateOperands().size(); 492 } 493 494 Value ParallelOp::getDataOperand(unsigned i) { 495 unsigned numOptional = async() ? 1 : 0; 496 numOptional += numGangs() ? 1 : 0; 497 numOptional += numWorkers() ? 1 : 0; 498 numOptional += vectorLength() ? 1 : 0; 499 numOptional += ifCond() ? 1 : 0; 500 numOptional += selfCond() ? 1 : 0; 501 return getOperand(waitOperands().size() + numOptional + i); 502 } 503 504 //===----------------------------------------------------------------------===// 505 // LoopOp 506 //===----------------------------------------------------------------------===// 507 508 /// Parse acc.loop operation 509 /// operation := `acc.loop` 510 /// (`gang` ( `(` (`num=` value)? (`,` `static=` value `)`)? )? )? 511 /// (`vector` ( `(` value `)` )? )? (`worker` (`(` value `)`)? )? 512 /// (`vector_length` `(` value `)`)? 513 /// (`tile` `(` value-list `)`)? 514 /// (`private` `(` value-list `)`)? 515 /// (`reduction` `(` value-list `)`)? 516 /// region attr-dict? 517 ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) { 518 Builder &builder = parser.getBuilder(); 519 unsigned executionMapping = OpenACCExecMapping::NONE; 520 SmallVector<Type, 8> operandTypes; 521 SmallVector<OpAsmParser::UnresolvedOperand, 8> privateOperands, 522 reductionOperands; 523 SmallVector<OpAsmParser::UnresolvedOperand, 8> tileOperands; 524 OptionalParseResult gangNum, gangStatic, worker, vector; 525 526 // gang? 527 if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangKeyword()))) 528 executionMapping |= OpenACCExecMapping::GANG; 529 530 // optional gang operand 531 if (succeeded(parser.parseOptionalLParen())) { 532 gangNum = parserOptionalOperandAndTypeWithPrefix( 533 parser, result, LoopOp::getGangNumKeyword()); 534 if (gangNum.hasValue() && failed(*gangNum)) 535 return failure(); 536 // FIXME: Comma should require subsequent operands. 537 (void)parser.parseOptionalComma(); 538 gangStatic = parserOptionalOperandAndTypeWithPrefix( 539 parser, result, LoopOp::getGangStaticKeyword()); 540 if (gangStatic.hasValue() && failed(*gangStatic)) 541 return failure(); 542 // FIXME: Why allow optional last commas? 543 (void)parser.parseOptionalComma(); 544 if (failed(parser.parseRParen())) 545 return failure(); 546 } 547 548 // worker? 549 if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerKeyword()))) 550 executionMapping |= OpenACCExecMapping::WORKER; 551 552 // optional worker operand 553 worker = parseOptionalOperandAndType(parser, result); 554 if (worker.hasValue() && failed(*worker)) 555 return failure(); 556 557 // vector? 558 if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorKeyword()))) 559 executionMapping |= OpenACCExecMapping::VECTOR; 560 561 // optional vector operand 562 vector = parseOptionalOperandAndType(parser, result); 563 if (vector.hasValue() && failed(*vector)) 564 return failure(); 565 566 // tile()? 567 if (failed(parseOperandList(parser, LoopOp::getTileKeyword(), tileOperands, 568 operandTypes, result))) 569 return failure(); 570 571 // private()? 572 if (failed(parseOperandList(parser, LoopOp::getPrivateKeyword(), 573 privateOperands, operandTypes, result))) 574 return failure(); 575 576 // reduction()? 577 if (failed(parseOperandList(parser, LoopOp::getReductionKeyword(), 578 reductionOperands, operandTypes, result))) 579 return failure(); 580 581 if (executionMapping != acc::OpenACCExecMapping::NONE) 582 result.addAttribute(LoopOp::getExecutionMappingAttrName(), 583 builder.getI64IntegerAttr(executionMapping)); 584 585 // Parse optional results in case there is a reduce. 586 if (parser.parseOptionalArrowTypeList(result.types)) 587 return failure(); 588 589 if (failed(parseRegions<LoopOp>(parser, result))) 590 return failure(); 591 592 result.addAttribute(LoopOp::getOperandSegmentSizeAttr(), 593 builder.getI32VectorAttr( 594 {static_cast<int32_t>(gangNum.hasValue() ? 1 : 0), 595 static_cast<int32_t>(gangStatic.hasValue() ? 1 : 0), 596 static_cast<int32_t>(worker.hasValue() ? 1 : 0), 597 static_cast<int32_t>(vector.hasValue() ? 1 : 0), 598 static_cast<int32_t>(tileOperands.size()), 599 static_cast<int32_t>(privateOperands.size()), 600 static_cast<int32_t>(reductionOperands.size())})); 601 602 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 603 return failure(); 604 605 return success(); 606 } 607 608 void LoopOp::print(OpAsmPrinter &printer) { 609 unsigned execMapping = exec_mapping(); 610 if (execMapping & OpenACCExecMapping::GANG) { 611 printer << " " << LoopOp::getGangKeyword(); 612 Value gangNum = this->gangNum(); 613 Value gangStatic = this->gangStatic(); 614 615 // Print optional gang operands 616 if (gangNum || gangStatic) { 617 printer << "("; 618 if (gangNum) { 619 printer << LoopOp::getGangNumKeyword() << "=" << gangNum << ": " 620 << gangNum.getType(); 621 if (gangStatic) 622 printer << ", "; 623 } 624 if (gangStatic) 625 printer << LoopOp::getGangStaticKeyword() << "=" << gangStatic << ": " 626 << gangStatic.getType(); 627 printer << ")"; 628 } 629 } 630 631 if (execMapping & OpenACCExecMapping::WORKER) { 632 printer << " " << LoopOp::getWorkerKeyword(); 633 634 // Print optional worker operand if present 635 if (Value workerNum = this->workerNum()) 636 printer << "(" << workerNum << ": " << workerNum.getType() << ")"; 637 } 638 639 if (execMapping & OpenACCExecMapping::VECTOR) { 640 printer << " " << LoopOp::getVectorKeyword(); 641 642 // Print optional vector operand if present 643 if (Value vectorLength = this->vectorLength()) 644 printer << "(" << vectorLength << ": " << vectorLength.getType() << ")"; 645 } 646 647 // tile()? 648 printOperandList(tileOperands(), LoopOp::getTileKeyword(), printer); 649 650 // private()? 651 printOperandList(privateOperands(), LoopOp::getPrivateKeyword(), printer); 652 653 // reduction()? 654 printOperandList(reductionOperands(), LoopOp::getReductionKeyword(), printer); 655 656 if (getNumResults() > 0) 657 printer << " -> (" << getResultTypes() << ")"; 658 659 printer << ' '; 660 printer.printRegion(region(), 661 /*printEntryBlockArgs=*/false, 662 /*printBlockTerminators=*/true); 663 664 printer.printOptionalAttrDictWithKeyword( 665 (*this)->getAttrs(), {LoopOp::getExecutionMappingAttrName(), 666 LoopOp::getOperandSegmentSizeAttr()}); 667 } 668 669 LogicalResult acc::LoopOp::verify() { 670 // auto, independent and seq attribute are mutually exclusive. 671 if ((auto_() && (independent() || seq())) || (independent() && seq())) { 672 return emitError("only one of " + acc::LoopOp::getAutoAttrName() + ", " + 673 acc::LoopOp::getIndependentAttrName() + ", " + 674 acc::LoopOp::getSeqAttrName() + 675 " can be present at the same time"); 676 } 677 678 // Gang, worker and vector are incompatible with seq. 679 if (seq() && exec_mapping() != OpenACCExecMapping::NONE) 680 return emitError("gang, worker or vector cannot appear with the seq attr"); 681 682 // Check non-empty body(). 683 if (region().empty()) 684 return emitError("expected non-empty body."); 685 686 return success(); 687 } 688 689 //===----------------------------------------------------------------------===// 690 // DataOp 691 //===----------------------------------------------------------------------===// 692 693 LogicalResult acc::DataOp::verify() { 694 // 2.6.5. Data Construct restriction 695 // At least one copy, copyin, copyout, create, no_create, present, deviceptr, 696 // attach, or default clause must appear on a data construct. 697 if (getOperands().empty() && !defaultAttr()) 698 return emitError("at least one operand or the default attribute " 699 "must appear on the data operation"); 700 return success(); 701 } 702 703 unsigned DataOp::getNumDataOperands() { 704 return copyOperands().size() + copyinOperands().size() + 705 copyinReadonlyOperands().size() + copyoutOperands().size() + 706 copyoutZeroOperands().size() + createOperands().size() + 707 createZeroOperands().size() + noCreateOperands().size() + 708 presentOperands().size() + deviceptrOperands().size() + 709 attachOperands().size(); 710 } 711 712 Value DataOp::getDataOperand(unsigned i) { 713 unsigned numOptional = ifCond() ? 1 : 0; 714 return getOperand(numOptional + i); 715 } 716 717 //===----------------------------------------------------------------------===// 718 // ExitDataOp 719 //===----------------------------------------------------------------------===// 720 721 LogicalResult acc::ExitDataOp::verify() { 722 // 2.6.6. Data Exit Directive restriction 723 // At least one copyout, delete, or detach clause must appear on an exit data 724 // directive. 725 if (copyoutOperands().empty() && deleteOperands().empty() && 726 detachOperands().empty()) 727 return emitError( 728 "at least one operand in copyout, delete or detach must appear on the " 729 "exit data operation"); 730 731 // The async attribute represent the async clause without value. Therefore the 732 // attribute and operand cannot appear at the same time. 733 if (asyncOperand() && async()) 734 return emitError("async attribute cannot appear with asyncOperand"); 735 736 // The wait attribute represent the wait clause without values. Therefore the 737 // attribute and operands cannot appear at the same time. 738 if (!waitOperands().empty() && wait()) 739 return emitError("wait attribute cannot appear with waitOperands"); 740 741 if (waitDevnum() && waitOperands().empty()) 742 return emitError("wait_devnum cannot appear without waitOperands"); 743 744 return success(); 745 } 746 747 unsigned ExitDataOp::getNumDataOperands() { 748 return copyoutOperands().size() + deleteOperands().size() + 749 detachOperands().size(); 750 } 751 752 Value ExitDataOp::getDataOperand(unsigned i) { 753 unsigned numOptional = ifCond() ? 1 : 0; 754 numOptional += asyncOperand() ? 1 : 0; 755 numOptional += waitDevnum() ? 1 : 0; 756 return getOperand(waitOperands().size() + numOptional + i); 757 } 758 759 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results, 760 MLIRContext *context) { 761 results.add<RemoveConstantIfCondition<ExitDataOp>>(context); 762 } 763 764 //===----------------------------------------------------------------------===// 765 // EnterDataOp 766 //===----------------------------------------------------------------------===// 767 768 LogicalResult acc::EnterDataOp::verify() { 769 // 2.6.6. Data Enter Directive restriction 770 // At least one copyin, create, or attach clause must appear on an enter data 771 // directive. 772 if (copyinOperands().empty() && createOperands().empty() && 773 createZeroOperands().empty() && attachOperands().empty()) 774 return emitError( 775 "at least one operand in copyin, create, " 776 "create_zero or attach must appear on the enter data operation"); 777 778 // The async attribute represent the async clause without value. Therefore the 779 // attribute and operand cannot appear at the same time. 780 if (asyncOperand() && async()) 781 return emitError("async attribute cannot appear with asyncOperand"); 782 783 // The wait attribute represent the wait clause without values. Therefore the 784 // attribute and operands cannot appear at the same time. 785 if (!waitOperands().empty() && wait()) 786 return emitError("wait attribute cannot appear with waitOperands"); 787 788 if (waitDevnum() && waitOperands().empty()) 789 return emitError("wait_devnum cannot appear without waitOperands"); 790 791 return success(); 792 } 793 794 unsigned EnterDataOp::getNumDataOperands() { 795 return copyinOperands().size() + createOperands().size() + 796 createZeroOperands().size() + attachOperands().size(); 797 } 798 799 Value EnterDataOp::getDataOperand(unsigned i) { 800 unsigned numOptional = ifCond() ? 1 : 0; 801 numOptional += asyncOperand() ? 1 : 0; 802 numOptional += waitDevnum() ? 1 : 0; 803 return getOperand(waitOperands().size() + numOptional + i); 804 } 805 806 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results, 807 MLIRContext *context) { 808 results.add<RemoveConstantIfCondition<EnterDataOp>>(context); 809 } 810 811 //===----------------------------------------------------------------------===// 812 // InitOp 813 //===----------------------------------------------------------------------===// 814 815 LogicalResult acc::InitOp::verify() { 816 Operation *currOp = *this; 817 while ((currOp = currOp->getParentOp())) 818 if (isComputeOperation(currOp)) 819 return emitOpError("cannot be nested in a compute operation"); 820 return success(); 821 } 822 823 //===----------------------------------------------------------------------===// 824 // ShutdownOp 825 //===----------------------------------------------------------------------===// 826 827 LogicalResult acc::ShutdownOp::verify() { 828 Operation *currOp = *this; 829 while ((currOp = currOp->getParentOp())) 830 if (isComputeOperation(currOp)) 831 return emitOpError("cannot be nested in a compute operation"); 832 return success(); 833 } 834 835 //===----------------------------------------------------------------------===// 836 // UpdateOp 837 //===----------------------------------------------------------------------===// 838 839 LogicalResult acc::UpdateOp::verify() { 840 // At least one of host or device should have a value. 841 if (hostOperands().empty() && deviceOperands().empty()) 842 return emitError( 843 "at least one value must be present in hostOperands or deviceOperands"); 844 845 // The async attribute represent the async clause without value. Therefore the 846 // attribute and operand cannot appear at the same time. 847 if (asyncOperand() && async()) 848 return emitError("async attribute cannot appear with asyncOperand"); 849 850 // The wait attribute represent the wait clause without values. Therefore the 851 // attribute and operands cannot appear at the same time. 852 if (!waitOperands().empty() && wait()) 853 return emitError("wait attribute cannot appear with waitOperands"); 854 855 if (waitDevnum() && waitOperands().empty()) 856 return emitError("wait_devnum cannot appear without waitOperands"); 857 858 return success(); 859 } 860 861 unsigned UpdateOp::getNumDataOperands() { 862 return hostOperands().size() + deviceOperands().size(); 863 } 864 865 Value UpdateOp::getDataOperand(unsigned i) { 866 unsigned numOptional = asyncOperand() ? 1 : 0; 867 numOptional += waitDevnum() ? 1 : 0; 868 numOptional += ifCond() ? 1 : 0; 869 return getOperand(waitOperands().size() + deviceTypeOperands().size() + 870 numOptional + i); 871 } 872 873 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results, 874 MLIRContext *context) { 875 results.add<RemoveConstantIfCondition<UpdateOp>>(context); 876 } 877 878 //===----------------------------------------------------------------------===// 879 // WaitOp 880 //===----------------------------------------------------------------------===// 881 882 LogicalResult acc::WaitOp::verify() { 883 // The async attribute represent the async clause without value. Therefore the 884 // attribute and operand cannot appear at the same time. 885 if (asyncOperand() && async()) 886 return emitError("async attribute cannot appear with asyncOperand"); 887 888 if (waitDevnum() && waitOperands().empty()) 889 return emitError("wait_devnum cannot appear without waitOperands"); 890 891 return success(); 892 } 893 894 #define GET_OP_CLASSES 895 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc" 896 897 #define GET_ATTRDEF_CLASSES 898 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc" 899