1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 // ============================================================================= 8 9 #include "mlir/Dialect/OpenACC/OpenACC.h" 10 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc" 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "mlir/IR/DialectImplementation.h" 14 #include "mlir/IR/Matchers.h" 15 #include "mlir/IR/OpImplementation.h" 16 #include "mlir/Transforms/DialectConversion.h" 17 #include "llvm/ADT/TypeSwitch.h" 18 19 using namespace mlir; 20 using namespace acc; 21 22 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc" 23 24 //===----------------------------------------------------------------------===// 25 // OpenACC operations 26 //===----------------------------------------------------------------------===// 27 28 void OpenACCDialect::initialize() { 29 addOperations< 30 #define GET_OP_LIST 31 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc" 32 >(); 33 addAttributes< 34 #define GET_ATTRDEF_LIST 35 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc" 36 >(); 37 } 38 39 template <typename StructureOp> 40 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, 41 unsigned nRegions = 1) { 42 43 SmallVector<Region *, 2> regions; 44 for (unsigned i = 0; i < nRegions; ++i) 45 regions.push_back(state.addRegion()); 46 47 for (Region *region : regions) { 48 if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{})) 49 return failure(); 50 } 51 52 return success(); 53 } 54 55 static ParseResult 56 parseOperandList(OpAsmParser &parser, StringRef keyword, 57 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &args, 58 SmallVectorImpl<Type> &argTypes, OperationState &result) { 59 if (failed(parser.parseOptionalKeyword(keyword))) 60 return success(); 61 62 if (failed(parser.parseLParen())) 63 return failure(); 64 65 // Exit early if the list is empty. 66 if (succeeded(parser.parseOptionalRParen())) 67 return success(); 68 69 do { 70 OpAsmParser::UnresolvedOperand arg; 71 Type type; 72 73 if (parser.parseRegionArgument(arg) || parser.parseColonType(type)) 74 return failure(); 75 76 args.push_back(arg); 77 argTypes.push_back(type); 78 } while (succeeded(parser.parseOptionalComma())); 79 80 if (failed(parser.parseRParen())) 81 return failure(); 82 83 return parser.resolveOperands(args, argTypes, parser.getCurrentLocation(), 84 result.operands); 85 } 86 87 static void printOperandList(Operation::operand_range operands, 88 StringRef listName, OpAsmPrinter &printer) { 89 90 if (!operands.empty()) { 91 printer << " " << listName << "("; 92 llvm::interleaveComma(operands, printer, [&](Value op) { 93 printer << op << ": " << op.getType(); 94 }); 95 printer << ")"; 96 } 97 } 98 99 static ParseResult parseOptionalOperand(OpAsmParser &parser, StringRef keyword, 100 OpAsmParser::UnresolvedOperand &operand, 101 Type type, bool &hasOptional, 102 OperationState &result) { 103 hasOptional = false; 104 if (succeeded(parser.parseOptionalKeyword(keyword))) { 105 hasOptional = true; 106 if (parser.parseLParen() || parser.parseOperand(operand) || 107 parser.resolveOperand(operand, type, result.operands) || 108 parser.parseRParen()) 109 return failure(); 110 } 111 return success(); 112 } 113 114 static ParseResult parseOperandAndType(OpAsmParser &parser, 115 OperationState &result) { 116 OpAsmParser::UnresolvedOperand operand; 117 Type type; 118 if (parser.parseOperand(operand) || parser.parseColonType(type) || 119 parser.resolveOperand(operand, type, result.operands)) 120 return failure(); 121 return success(); 122 } 123 124 /// Parse optional operand and its type wrapped in parenthesis prefixed with 125 /// a keyword. 126 /// Example: 127 /// keyword `(` %vectorLength: i64 `)` 128 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser, 129 StringRef keyword, 130 OperationState &result) { 131 OpAsmParser::UnresolvedOperand operand; 132 if (succeeded(parser.parseOptionalKeyword(keyword))) { 133 return failure(parser.parseLParen() || 134 parseOperandAndType(parser, result) || parser.parseRParen()); 135 } 136 return llvm::None; 137 } 138 139 /// Parse optional operand and its type wrapped in parenthesis. 140 /// Example: 141 /// `(` %vectorLength: i64 `)` 142 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser, 143 OperationState &result) { 144 if (succeeded(parser.parseOptionalLParen())) { 145 return failure(parseOperandAndType(parser, result) || parser.parseRParen()); 146 } 147 return llvm::None; 148 } 149 150 /// Parse optional operand with its type prefixed with prefixKeyword `=`. 151 /// Example: 152 /// num=%gangNum: i32 153 static OptionalParseResult parserOptionalOperandAndTypeWithPrefix( 154 OpAsmParser &parser, OperationState &result, StringRef prefixKeyword) { 155 if (succeeded(parser.parseOptionalKeyword(prefixKeyword))) { 156 parser.parseEqual(); 157 return parseOperandAndType(parser, result); 158 } 159 return llvm::None; 160 } 161 162 static bool isComputeOperation(Operation *op) { 163 return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op); 164 } 165 166 namespace { 167 /// Pattern to remove operation without region that have constant false `ifCond` 168 /// and remove the condition from the operation if the `ifCond` is a true 169 /// constant. 170 template <typename OpTy> 171 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> { 172 using OpRewritePattern<OpTy>::OpRewritePattern; 173 174 LogicalResult matchAndRewrite(OpTy op, 175 PatternRewriter &rewriter) const override { 176 // Early return if there is no condition. 177 Value ifCond = op.ifCond(); 178 if (!ifCond) 179 return success(); 180 181 IntegerAttr constAttr; 182 if (matchPattern(ifCond, m_Constant(&constAttr))) { 183 if (constAttr.getInt()) 184 rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); }); 185 else 186 rewriter.eraseOp(op); 187 } 188 189 return success(); 190 } 191 }; 192 } // namespace 193 194 //===----------------------------------------------------------------------===// 195 // ParallelOp 196 //===----------------------------------------------------------------------===// 197 198 /// Parse acc.parallel operation 199 /// operation := `acc.parallel` `async` `(` index `)`? 200 /// `wait` `(` index-list `)`? 201 /// `num_gangs` `(` value `)`? 202 /// `num_workers` `(` value `)`? 203 /// `vector_length` `(` value `)`? 204 /// `if` `(` value `)`? 205 /// `self` `(` value `)`? 206 /// `reduction` `(` value-list `)`? 207 /// `copy` `(` value-list `)`? 208 /// `copyin` `(` value-list `)`? 209 /// `copyin_readonly` `(` value-list `)`? 210 /// `copyout` `(` value-list `)`? 211 /// `copyout_zero` `(` value-list `)`? 212 /// `create` `(` value-list `)`? 213 /// `create_zero` `(` value-list `)`? 214 /// `no_create` `(` value-list `)`? 215 /// `present` `(` value-list `)`? 216 /// `deviceptr` `(` value-list `)`? 217 /// `attach` `(` value-list `)`? 218 /// `private` `(` value-list `)`? 219 /// `firstprivate` `(` value-list `)`? 220 /// region attr-dict? 221 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) { 222 Builder &builder = parser.getBuilder(); 223 SmallVector<OpAsmParser::UnresolvedOperand, 8> privateOperands, 224 firstprivateOperands, copyOperands, copyinOperands, 225 copyinReadonlyOperands, copyoutOperands, copyoutZeroOperands, 226 createOperands, createZeroOperands, noCreateOperands, presentOperands, 227 devicePtrOperands, attachOperands, waitOperands, reductionOperands; 228 SmallVector<Type, 8> waitOperandTypes, reductionOperandTypes, 229 copyOperandTypes, copyinOperandTypes, copyinReadonlyOperandTypes, 230 copyoutOperandTypes, copyoutZeroOperandTypes, createOperandTypes, 231 createZeroOperandTypes, noCreateOperandTypes, presentOperandTypes, 232 deviceptrOperandTypes, attachOperandTypes, privateOperandTypes, 233 firstprivateOperandTypes; 234 235 SmallVector<Type, 8> operandTypes; 236 OpAsmParser::UnresolvedOperand ifCond, selfCond; 237 bool hasIfCond = false, hasSelfCond = false; 238 OptionalParseResult async, numGangs, numWorkers, vectorLength; 239 Type i1Type = builder.getI1Type(); 240 241 // async()? 242 async = parseOptionalOperandAndType(parser, ParallelOp::getAsyncKeyword(), 243 result); 244 if (async.hasValue() && failed(*async)) 245 return failure(); 246 247 // wait()? 248 if (failed(parseOperandList(parser, ParallelOp::getWaitKeyword(), 249 waitOperands, waitOperandTypes, result))) 250 return failure(); 251 252 // num_gangs(value)? 253 numGangs = parseOptionalOperandAndType( 254 parser, ParallelOp::getNumGangsKeyword(), result); 255 if (numGangs.hasValue() && failed(*numGangs)) 256 return failure(); 257 258 // num_workers(value)? 259 numWorkers = parseOptionalOperandAndType( 260 parser, ParallelOp::getNumWorkersKeyword(), result); 261 if (numWorkers.hasValue() && failed(*numWorkers)) 262 return failure(); 263 264 // vector_length(value)? 265 vectorLength = parseOptionalOperandAndType( 266 parser, ParallelOp::getVectorLengthKeyword(), result); 267 if (vectorLength.hasValue() && failed(*vectorLength)) 268 return failure(); 269 270 // if()? 271 if (failed(parseOptionalOperand(parser, ParallelOp::getIfKeyword(), ifCond, 272 i1Type, hasIfCond, result))) 273 return failure(); 274 275 // self()? 276 if (failed(parseOptionalOperand(parser, ParallelOp::getSelfKeyword(), 277 selfCond, i1Type, hasSelfCond, result))) 278 return failure(); 279 280 // reduction()? 281 if (failed(parseOperandList(parser, ParallelOp::getReductionKeyword(), 282 reductionOperands, reductionOperandTypes, 283 result))) 284 return failure(); 285 286 // copy()? 287 if (failed(parseOperandList(parser, ParallelOp::getCopyKeyword(), 288 copyOperands, copyOperandTypes, result))) 289 return failure(); 290 291 // copyin()? 292 if (failed(parseOperandList(parser, ParallelOp::getCopyinKeyword(), 293 copyinOperands, copyinOperandTypes, result))) 294 return failure(); 295 296 // copyin_readonly()? 297 if (failed(parseOperandList(parser, ParallelOp::getCopyinReadonlyKeyword(), 298 copyinReadonlyOperands, 299 copyinReadonlyOperandTypes, result))) 300 return failure(); 301 302 // copyout()? 303 if (failed(parseOperandList(parser, ParallelOp::getCopyoutKeyword(), 304 copyoutOperands, copyoutOperandTypes, result))) 305 return failure(); 306 307 // copyout_zero()? 308 if (failed(parseOperandList(parser, ParallelOp::getCopyoutZeroKeyword(), 309 copyoutZeroOperands, copyoutZeroOperandTypes, 310 result))) 311 return failure(); 312 313 // create()? 314 if (failed(parseOperandList(parser, ParallelOp::getCreateKeyword(), 315 createOperands, createOperandTypes, result))) 316 return failure(); 317 318 // create_zero()? 319 if (failed(parseOperandList(parser, ParallelOp::getCreateZeroKeyword(), 320 createZeroOperands, createZeroOperandTypes, 321 result))) 322 return failure(); 323 324 // no_create()? 325 if (failed(parseOperandList(parser, ParallelOp::getNoCreateKeyword(), 326 noCreateOperands, noCreateOperandTypes, result))) 327 return failure(); 328 329 // present()? 330 if (failed(parseOperandList(parser, ParallelOp::getPresentKeyword(), 331 presentOperands, presentOperandTypes, result))) 332 return failure(); 333 334 // deviceptr()? 335 if (failed(parseOperandList(parser, ParallelOp::getDevicePtrKeyword(), 336 devicePtrOperands, deviceptrOperandTypes, 337 result))) 338 return failure(); 339 340 // attach()? 341 if (failed(parseOperandList(parser, ParallelOp::getAttachKeyword(), 342 attachOperands, attachOperandTypes, result))) 343 return failure(); 344 345 // private()? 346 if (failed(parseOperandList(parser, ParallelOp::getPrivateKeyword(), 347 privateOperands, privateOperandTypes, result))) 348 return failure(); 349 350 // firstprivate()? 351 if (failed(parseOperandList(parser, ParallelOp::getFirstPrivateKeyword(), 352 firstprivateOperands, firstprivateOperandTypes, 353 result))) 354 return failure(); 355 356 // Parallel op region 357 if (failed(parseRegions<ParallelOp>(parser, result))) 358 return failure(); 359 360 result.addAttribute( 361 ParallelOp::getOperandSegmentSizeAttr(), 362 builder.getI32VectorAttr( 363 {static_cast<int32_t>(async.hasValue() ? 1 : 0), 364 static_cast<int32_t>(waitOperands.size()), 365 static_cast<int32_t>(numGangs.hasValue() ? 1 : 0), 366 static_cast<int32_t>(numWorkers.hasValue() ? 1 : 0), 367 static_cast<int32_t>(vectorLength.hasValue() ? 1 : 0), 368 static_cast<int32_t>(hasIfCond ? 1 : 0), 369 static_cast<int32_t>(hasSelfCond ? 1 : 0), 370 static_cast<int32_t>(reductionOperands.size()), 371 static_cast<int32_t>(copyOperands.size()), 372 static_cast<int32_t>(copyinOperands.size()), 373 static_cast<int32_t>(copyinReadonlyOperands.size()), 374 static_cast<int32_t>(copyoutOperands.size()), 375 static_cast<int32_t>(copyoutZeroOperands.size()), 376 static_cast<int32_t>(createOperands.size()), 377 static_cast<int32_t>(createZeroOperands.size()), 378 static_cast<int32_t>(noCreateOperands.size()), 379 static_cast<int32_t>(presentOperands.size()), 380 static_cast<int32_t>(devicePtrOperands.size()), 381 static_cast<int32_t>(attachOperands.size()), 382 static_cast<int32_t>(privateOperands.size()), 383 static_cast<int32_t>(firstprivateOperands.size())})); 384 385 // Additional attributes 386 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 387 return failure(); 388 389 return success(); 390 } 391 392 void ParallelOp::print(OpAsmPrinter &printer) { 393 // async()? 394 if (Value async = this->async()) 395 printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ": " 396 << async.getType() << ")"; 397 398 // wait()? 399 printOperandList(waitOperands(), ParallelOp::getWaitKeyword(), printer); 400 401 // num_gangs()? 402 if (Value numGangs = this->numGangs()) 403 printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs 404 << ": " << numGangs.getType() << ")"; 405 406 // num_workers()? 407 if (Value numWorkers = this->numWorkers()) 408 printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers 409 << ": " << numWorkers.getType() << ")"; 410 411 // vector_length()? 412 if (Value vectorLength = this->vectorLength()) 413 printer << " " << ParallelOp::getVectorLengthKeyword() << "(" 414 << vectorLength << ": " << vectorLength.getType() << ")"; 415 416 // if()? 417 if (Value ifCond = this->ifCond()) 418 printer << " " << ParallelOp::getIfKeyword() << "(" << ifCond << ")"; 419 420 // self()? 421 if (Value selfCond = this->selfCond()) 422 printer << " " << ParallelOp::getSelfKeyword() << "(" << selfCond << ")"; 423 424 // reduction()? 425 printOperandList(reductionOperands(), ParallelOp::getReductionKeyword(), 426 printer); 427 428 // copy()? 429 printOperandList(copyOperands(), ParallelOp::getCopyKeyword(), printer); 430 431 // copyin()? 432 printOperandList(copyinOperands(), ParallelOp::getCopyinKeyword(), printer); 433 434 // copyin_readonly()? 435 printOperandList(copyinReadonlyOperands(), 436 ParallelOp::getCopyinReadonlyKeyword(), printer); 437 438 // copyout()? 439 printOperandList(copyoutOperands(), ParallelOp::getCopyoutKeyword(), printer); 440 441 // copyout_zero()? 442 printOperandList(copyoutZeroOperands(), ParallelOp::getCopyoutZeroKeyword(), 443 printer); 444 445 // create()? 446 printOperandList(createOperands(), ParallelOp::getCreateKeyword(), printer); 447 448 // create_zero()? 449 printOperandList(createZeroOperands(), ParallelOp::getCreateZeroKeyword(), 450 printer); 451 452 // no_create()? 453 printOperandList(noCreateOperands(), ParallelOp::getNoCreateKeyword(), 454 printer); 455 456 // present()? 457 printOperandList(presentOperands(), ParallelOp::getPresentKeyword(), printer); 458 459 // deviceptr()? 460 printOperandList(devicePtrOperands(), ParallelOp::getDevicePtrKeyword(), 461 printer); 462 463 // attach()? 464 printOperandList(attachOperands(), ParallelOp::getAttachKeyword(), printer); 465 466 // private()? 467 printOperandList(gangPrivateOperands(), ParallelOp::getPrivateKeyword(), 468 printer); 469 470 // firstprivate()? 471 printOperandList(gangFirstPrivateOperands(), 472 ParallelOp::getFirstPrivateKeyword(), printer); 473 474 printer << ' '; 475 printer.printRegion(region(), 476 /*printEntryBlockArgs=*/false, 477 /*printBlockTerminators=*/true); 478 printer.printOptionalAttrDictWithKeyword( 479 (*this)->getAttrs(), ParallelOp::getOperandSegmentSizeAttr()); 480 } 481 482 unsigned ParallelOp::getNumDataOperands() { 483 return reductionOperands().size() + copyOperands().size() + 484 copyinOperands().size() + copyinReadonlyOperands().size() + 485 copyoutOperands().size() + copyoutZeroOperands().size() + 486 createOperands().size() + createZeroOperands().size() + 487 noCreateOperands().size() + presentOperands().size() + 488 devicePtrOperands().size() + attachOperands().size() + 489 gangPrivateOperands().size() + gangFirstPrivateOperands().size(); 490 } 491 492 Value ParallelOp::getDataOperand(unsigned i) { 493 unsigned numOptional = async() ? 1 : 0; 494 numOptional += numGangs() ? 1 : 0; 495 numOptional += numWorkers() ? 1 : 0; 496 numOptional += vectorLength() ? 1 : 0; 497 numOptional += ifCond() ? 1 : 0; 498 numOptional += selfCond() ? 1 : 0; 499 return getOperand(waitOperands().size() + numOptional + i); 500 } 501 502 //===----------------------------------------------------------------------===// 503 // LoopOp 504 //===----------------------------------------------------------------------===// 505 506 /// Parse acc.loop operation 507 /// operation := `acc.loop` 508 /// (`gang` ( `(` (`num=` value)? (`,` `static=` value `)`)? )? )? 509 /// (`vector` ( `(` value `)` )? )? (`worker` (`(` value `)`)? )? 510 /// (`vector_length` `(` value `)`)? 511 /// (`tile` `(` value-list `)`)? 512 /// (`private` `(` value-list `)`)? 513 /// (`reduction` `(` value-list `)`)? 514 /// region attr-dict? 515 ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) { 516 Builder &builder = parser.getBuilder(); 517 unsigned executionMapping = OpenACCExecMapping::NONE; 518 SmallVector<Type, 8> operandTypes; 519 SmallVector<OpAsmParser::UnresolvedOperand, 8> privateOperands, 520 reductionOperands; 521 SmallVector<OpAsmParser::UnresolvedOperand, 8> tileOperands; 522 OptionalParseResult gangNum, gangStatic, worker, vector; 523 524 // gang? 525 if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangKeyword()))) 526 executionMapping |= OpenACCExecMapping::GANG; 527 528 // optional gang operand 529 if (succeeded(parser.parseOptionalLParen())) { 530 gangNum = parserOptionalOperandAndTypeWithPrefix( 531 parser, result, LoopOp::getGangNumKeyword()); 532 if (gangNum.hasValue() && failed(*gangNum)) 533 return failure(); 534 parser.parseOptionalComma(); 535 gangStatic = parserOptionalOperandAndTypeWithPrefix( 536 parser, result, LoopOp::getGangStaticKeyword()); 537 if (gangStatic.hasValue() && failed(*gangStatic)) 538 return failure(); 539 parser.parseOptionalComma(); 540 if (failed(parser.parseRParen())) 541 return failure(); 542 } 543 544 // worker? 545 if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerKeyword()))) 546 executionMapping |= OpenACCExecMapping::WORKER; 547 548 // optional worker operand 549 worker = parseOptionalOperandAndType(parser, result); 550 if (worker.hasValue() && failed(*worker)) 551 return failure(); 552 553 // vector? 554 if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorKeyword()))) 555 executionMapping |= OpenACCExecMapping::VECTOR; 556 557 // optional vector operand 558 vector = parseOptionalOperandAndType(parser, result); 559 if (vector.hasValue() && failed(*vector)) 560 return failure(); 561 562 // tile()? 563 if (failed(parseOperandList(parser, LoopOp::getTileKeyword(), tileOperands, 564 operandTypes, result))) 565 return failure(); 566 567 // private()? 568 if (failed(parseOperandList(parser, LoopOp::getPrivateKeyword(), 569 privateOperands, operandTypes, result))) 570 return failure(); 571 572 // reduction()? 573 if (failed(parseOperandList(parser, LoopOp::getReductionKeyword(), 574 reductionOperands, operandTypes, result))) 575 return failure(); 576 577 if (executionMapping != acc::OpenACCExecMapping::NONE) 578 result.addAttribute(LoopOp::getExecutionMappingAttrName(), 579 builder.getI64IntegerAttr(executionMapping)); 580 581 // Parse optional results in case there is a reduce. 582 if (parser.parseOptionalArrowTypeList(result.types)) 583 return failure(); 584 585 if (failed(parseRegions<LoopOp>(parser, result))) 586 return failure(); 587 588 result.addAttribute(LoopOp::getOperandSegmentSizeAttr(), 589 builder.getI32VectorAttr( 590 {static_cast<int32_t>(gangNum.hasValue() ? 1 : 0), 591 static_cast<int32_t>(gangStatic.hasValue() ? 1 : 0), 592 static_cast<int32_t>(worker.hasValue() ? 1 : 0), 593 static_cast<int32_t>(vector.hasValue() ? 1 : 0), 594 static_cast<int32_t>(tileOperands.size()), 595 static_cast<int32_t>(privateOperands.size()), 596 static_cast<int32_t>(reductionOperands.size())})); 597 598 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 599 return failure(); 600 601 return success(); 602 } 603 604 void LoopOp::print(OpAsmPrinter &printer) { 605 unsigned execMapping = exec_mapping(); 606 if (execMapping & OpenACCExecMapping::GANG) { 607 printer << " " << LoopOp::getGangKeyword(); 608 Value gangNum = this->gangNum(); 609 Value gangStatic = this->gangStatic(); 610 611 // Print optional gang operands 612 if (gangNum || gangStatic) { 613 printer << "("; 614 if (gangNum) { 615 printer << LoopOp::getGangNumKeyword() << "=" << gangNum << ": " 616 << gangNum.getType(); 617 if (gangStatic) 618 printer << ", "; 619 } 620 if (gangStatic) 621 printer << LoopOp::getGangStaticKeyword() << "=" << gangStatic << ": " 622 << gangStatic.getType(); 623 printer << ")"; 624 } 625 } 626 627 if (execMapping & OpenACCExecMapping::WORKER) { 628 printer << " " << LoopOp::getWorkerKeyword(); 629 630 // Print optional worker operand if present 631 if (Value workerNum = this->workerNum()) 632 printer << "(" << workerNum << ": " << workerNum.getType() << ")"; 633 } 634 635 if (execMapping & OpenACCExecMapping::VECTOR) { 636 printer << " " << LoopOp::getVectorKeyword(); 637 638 // Print optional vector operand if present 639 if (Value vectorLength = this->vectorLength()) 640 printer << "(" << vectorLength << ": " << vectorLength.getType() << ")"; 641 } 642 643 // tile()? 644 printOperandList(tileOperands(), LoopOp::getTileKeyword(), printer); 645 646 // private()? 647 printOperandList(privateOperands(), LoopOp::getPrivateKeyword(), printer); 648 649 // reduction()? 650 printOperandList(reductionOperands(), LoopOp::getReductionKeyword(), printer); 651 652 if (getNumResults() > 0) 653 printer << " -> (" << getResultTypes() << ")"; 654 655 printer << ' '; 656 printer.printRegion(region(), 657 /*printEntryBlockArgs=*/false, 658 /*printBlockTerminators=*/true); 659 660 printer.printOptionalAttrDictWithKeyword( 661 (*this)->getAttrs(), {LoopOp::getExecutionMappingAttrName(), 662 LoopOp::getOperandSegmentSizeAttr()}); 663 } 664 665 LogicalResult acc::LoopOp::verify() { 666 // auto, independent and seq attribute are mutually exclusive. 667 if ((auto_() && (independent() || seq())) || (independent() && seq())) { 668 return emitError("only one of " + acc::LoopOp::getAutoAttrName() + ", " + 669 acc::LoopOp::getIndependentAttrName() + ", " + 670 acc::LoopOp::getSeqAttrName() + 671 " can be present at the same time"); 672 } 673 674 // Gang, worker and vector are incompatible with seq. 675 if (seq() && exec_mapping() != OpenACCExecMapping::NONE) 676 return emitError("gang, worker or vector cannot appear with the seq attr"); 677 678 // Check non-empty body(). 679 if (region().empty()) 680 return emitError("expected non-empty body."); 681 682 return success(); 683 } 684 685 //===----------------------------------------------------------------------===// 686 // DataOp 687 //===----------------------------------------------------------------------===// 688 689 LogicalResult acc::DataOp::verify() { 690 // 2.6.5. Data Construct restriction 691 // At least one copy, copyin, copyout, create, no_create, present, deviceptr, 692 // attach, or default clause must appear on a data construct. 693 if (getOperands().empty() && !defaultAttr()) 694 return emitError("at least one operand or the default attribute " 695 "must appear on the data operation"); 696 return success(); 697 } 698 699 unsigned DataOp::getNumDataOperands() { 700 return copyOperands().size() + copyinOperands().size() + 701 copyinReadonlyOperands().size() + copyoutOperands().size() + 702 copyoutZeroOperands().size() + createOperands().size() + 703 createZeroOperands().size() + noCreateOperands().size() + 704 presentOperands().size() + deviceptrOperands().size() + 705 attachOperands().size(); 706 } 707 708 Value DataOp::getDataOperand(unsigned i) { 709 unsigned numOptional = ifCond() ? 1 : 0; 710 return getOperand(numOptional + i); 711 } 712 713 //===----------------------------------------------------------------------===// 714 // ExitDataOp 715 //===----------------------------------------------------------------------===// 716 717 LogicalResult acc::ExitDataOp::verify() { 718 // 2.6.6. Data Exit Directive restriction 719 // At least one copyout, delete, or detach clause must appear on an exit data 720 // directive. 721 if (copyoutOperands().empty() && deleteOperands().empty() && 722 detachOperands().empty()) 723 return emitError( 724 "at least one operand in copyout, delete or detach must appear on the " 725 "exit data operation"); 726 727 // The async attribute represent the async clause without value. Therefore the 728 // attribute and operand cannot appear at the same time. 729 if (asyncOperand() && async()) 730 return emitError("async attribute cannot appear with asyncOperand"); 731 732 // The wait attribute represent the wait clause without values. Therefore the 733 // attribute and operands cannot appear at the same time. 734 if (!waitOperands().empty() && wait()) 735 return emitError("wait attribute cannot appear with waitOperands"); 736 737 if (waitDevnum() && waitOperands().empty()) 738 return emitError("wait_devnum cannot appear without waitOperands"); 739 740 return success(); 741 } 742 743 unsigned ExitDataOp::getNumDataOperands() { 744 return copyoutOperands().size() + deleteOperands().size() + 745 detachOperands().size(); 746 } 747 748 Value ExitDataOp::getDataOperand(unsigned i) { 749 unsigned numOptional = ifCond() ? 1 : 0; 750 numOptional += asyncOperand() ? 1 : 0; 751 numOptional += waitDevnum() ? 1 : 0; 752 return getOperand(waitOperands().size() + numOptional + i); 753 } 754 755 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results, 756 MLIRContext *context) { 757 results.add<RemoveConstantIfCondition<ExitDataOp>>(context); 758 } 759 760 //===----------------------------------------------------------------------===// 761 // EnterDataOp 762 //===----------------------------------------------------------------------===// 763 764 LogicalResult acc::EnterDataOp::verify() { 765 // 2.6.6. Data Enter Directive restriction 766 // At least one copyin, create, or attach clause must appear on an enter data 767 // directive. 768 if (copyinOperands().empty() && createOperands().empty() && 769 createZeroOperands().empty() && attachOperands().empty()) 770 return emitError( 771 "at least one operand in copyin, create, " 772 "create_zero or attach must appear on the enter data operation"); 773 774 // The async attribute represent the async clause without value. Therefore the 775 // attribute and operand cannot appear at the same time. 776 if (asyncOperand() && async()) 777 return emitError("async attribute cannot appear with asyncOperand"); 778 779 // The wait attribute represent the wait clause without values. Therefore the 780 // attribute and operands cannot appear at the same time. 781 if (!waitOperands().empty() && wait()) 782 return emitError("wait attribute cannot appear with waitOperands"); 783 784 if (waitDevnum() && waitOperands().empty()) 785 return emitError("wait_devnum cannot appear without waitOperands"); 786 787 return success(); 788 } 789 790 unsigned EnterDataOp::getNumDataOperands() { 791 return copyinOperands().size() + createOperands().size() + 792 createZeroOperands().size() + attachOperands().size(); 793 } 794 795 Value EnterDataOp::getDataOperand(unsigned i) { 796 unsigned numOptional = ifCond() ? 1 : 0; 797 numOptional += asyncOperand() ? 1 : 0; 798 numOptional += waitDevnum() ? 1 : 0; 799 return getOperand(waitOperands().size() + numOptional + i); 800 } 801 802 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results, 803 MLIRContext *context) { 804 results.add<RemoveConstantIfCondition<EnterDataOp>>(context); 805 } 806 807 //===----------------------------------------------------------------------===// 808 // InitOp 809 //===----------------------------------------------------------------------===// 810 811 LogicalResult acc::InitOp::verify() { 812 Operation *currOp = *this; 813 while ((currOp = currOp->getParentOp())) 814 if (isComputeOperation(currOp)) 815 return emitOpError("cannot be nested in a compute operation"); 816 return success(); 817 } 818 819 //===----------------------------------------------------------------------===// 820 // ShutdownOp 821 //===----------------------------------------------------------------------===// 822 823 LogicalResult acc::ShutdownOp::verify() { 824 Operation *currOp = *this; 825 while ((currOp = currOp->getParentOp())) 826 if (isComputeOperation(currOp)) 827 return emitOpError("cannot be nested in a compute operation"); 828 return success(); 829 } 830 831 //===----------------------------------------------------------------------===// 832 // UpdateOp 833 //===----------------------------------------------------------------------===// 834 835 LogicalResult acc::UpdateOp::verify() { 836 // At least one of host or device should have a value. 837 if (hostOperands().empty() && deviceOperands().empty()) 838 return emitError( 839 "at least one value must be present in hostOperands or deviceOperands"); 840 841 // The async attribute represent the async clause without value. Therefore the 842 // attribute and operand cannot appear at the same time. 843 if (asyncOperand() && async()) 844 return emitError("async attribute cannot appear with asyncOperand"); 845 846 // The wait attribute represent the wait clause without values. Therefore the 847 // attribute and operands cannot appear at the same time. 848 if (!waitOperands().empty() && wait()) 849 return emitError("wait attribute cannot appear with waitOperands"); 850 851 if (waitDevnum() && waitOperands().empty()) 852 return emitError("wait_devnum cannot appear without waitOperands"); 853 854 return success(); 855 } 856 857 unsigned UpdateOp::getNumDataOperands() { 858 return hostOperands().size() + deviceOperands().size(); 859 } 860 861 Value UpdateOp::getDataOperand(unsigned i) { 862 unsigned numOptional = asyncOperand() ? 1 : 0; 863 numOptional += waitDevnum() ? 1 : 0; 864 numOptional += ifCond() ? 1 : 0; 865 return getOperand(waitOperands().size() + deviceTypeOperands().size() + 866 numOptional + i); 867 } 868 869 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results, 870 MLIRContext *context) { 871 results.add<RemoveConstantIfCondition<UpdateOp>>(context); 872 } 873 874 //===----------------------------------------------------------------------===// 875 // WaitOp 876 //===----------------------------------------------------------------------===// 877 878 LogicalResult acc::WaitOp::verify() { 879 // The async attribute represent the async clause without value. Therefore the 880 // attribute and operand cannot appear at the same time. 881 if (asyncOperand() && async()) 882 return emitError("async attribute cannot appear with asyncOperand"); 883 884 if (waitDevnum() && waitOperands().empty()) 885 return emitError("wait_devnum cannot appear without waitOperands"); 886 887 return success(); 888 } 889 890 #define GET_OP_CLASSES 891 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc" 892 893 #define GET_ATTRDEF_CLASSES 894 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc" 895