1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 // ============================================================================= 8 9 #include "mlir/Dialect/OpenACC/OpenACC.h" 10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 11 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc" 12 #include "mlir/Dialect/StandardOps/IR/Ops.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/BuiltinTypes.h" 15 #include "mlir/IR/DialectImplementation.h" 16 #include "mlir/IR/OpImplementation.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 #include "llvm/ADT/TypeSwitch.h" 19 20 using namespace mlir; 21 using namespace acc; 22 23 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc" 24 25 //===----------------------------------------------------------------------===// 26 // OpenACC operations 27 //===----------------------------------------------------------------------===// 28 29 void OpenACCDialect::initialize() { 30 addOperations< 31 #define GET_OP_LIST 32 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc" 33 >(); 34 addAttributes< 35 #define GET_ATTRDEF_LIST 36 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc" 37 >(); 38 } 39 40 template <typename StructureOp> 41 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, 42 unsigned nRegions = 1) { 43 44 SmallVector<Region *, 2> regions; 45 for (unsigned i = 0; i < nRegions; ++i) 46 regions.push_back(state.addRegion()); 47 48 for (Region *region : regions) { 49 if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{})) 50 return failure(); 51 } 52 53 return success(); 54 } 55 56 static ParseResult 57 parseOperandList(OpAsmParser &parser, StringRef keyword, 58 SmallVectorImpl<OpAsmParser::OperandType> &args, 59 SmallVectorImpl<Type> &argTypes, OperationState &result) { 60 if (failed(parser.parseOptionalKeyword(keyword))) 61 return success(); 62 63 if (failed(parser.parseLParen())) 64 return failure(); 65 66 // Exit early if the list is empty. 67 if (succeeded(parser.parseOptionalRParen())) 68 return success(); 69 70 do { 71 OpAsmParser::OperandType arg; 72 Type type; 73 74 if (parser.parseRegionArgument(arg) || parser.parseColonType(type)) 75 return failure(); 76 77 args.push_back(arg); 78 argTypes.push_back(type); 79 } while (succeeded(parser.parseOptionalComma())); 80 81 if (failed(parser.parseRParen())) 82 return failure(); 83 84 return parser.resolveOperands(args, argTypes, parser.getCurrentLocation(), 85 result.operands); 86 } 87 88 static void printOperandList(Operation::operand_range operands, 89 StringRef listName, OpAsmPrinter &printer) { 90 91 if (!operands.empty()) { 92 printer << " " << listName << "("; 93 llvm::interleaveComma(operands, printer, [&](Value op) { 94 printer << op << ": " << op.getType(); 95 }); 96 printer << ")"; 97 } 98 } 99 100 static ParseResult parseOptionalOperand(OpAsmParser &parser, StringRef keyword, 101 OpAsmParser::OperandType &operand, 102 Type type, bool &hasOptional, 103 OperationState &result) { 104 hasOptional = false; 105 if (succeeded(parser.parseOptionalKeyword(keyword))) { 106 hasOptional = true; 107 if (parser.parseLParen() || parser.parseOperand(operand) || 108 parser.resolveOperand(operand, type, result.operands) || 109 parser.parseRParen()) 110 return failure(); 111 } 112 return success(); 113 } 114 115 static ParseResult parseOperandAndType(OpAsmParser &parser, 116 OperationState &result) { 117 OpAsmParser::OperandType operand; 118 Type type; 119 if (parser.parseOperand(operand) || parser.parseColonType(type) || 120 parser.resolveOperand(operand, type, result.operands)) 121 return failure(); 122 return success(); 123 } 124 125 /// Parse optional operand and its type wrapped in parenthesis prefixed with 126 /// a keyword. 127 /// Example: 128 /// keyword `(` %vectorLength: i64 `)` 129 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser, 130 StringRef keyword, 131 OperationState &result) { 132 OpAsmParser::OperandType operand; 133 if (succeeded(parser.parseOptionalKeyword(keyword))) { 134 return failure(parser.parseLParen() || 135 parseOperandAndType(parser, result) || parser.parseRParen()); 136 } 137 return llvm::None; 138 } 139 140 /// Parse optional operand and its type wrapped in parenthesis. 141 /// Example: 142 /// `(` %vectorLength: i64 `)` 143 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser, 144 OperationState &result) { 145 if (succeeded(parser.parseOptionalLParen())) { 146 return failure(parseOperandAndType(parser, result) || parser.parseRParen()); 147 } 148 return llvm::None; 149 } 150 151 /// Parse optional operand with its type prefixed with prefixKeyword `=`. 152 /// Example: 153 /// num=%gangNum: i32 154 static OptionalParseResult parserOptionalOperandAndTypeWithPrefix( 155 OpAsmParser &parser, OperationState &result, StringRef prefixKeyword) { 156 if (succeeded(parser.parseOptionalKeyword(prefixKeyword))) { 157 parser.parseEqual(); 158 return parseOperandAndType(parser, result); 159 } 160 return llvm::None; 161 } 162 163 static bool isComputeOperation(Operation *op) { 164 return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op); 165 } 166 167 namespace { 168 /// Pattern to remove operation without region that have constant false `ifCond` 169 /// and remove the condition from the operation if the `ifCond` is a true 170 /// constant. 171 template <typename OpTy> 172 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> { 173 using OpRewritePattern<OpTy>::OpRewritePattern; 174 175 LogicalResult matchAndRewrite(OpTy op, 176 PatternRewriter &rewriter) const override { 177 // Early return if there is no condition. 178 if (!op.ifCond()) 179 return success(); 180 181 auto constOp = op.ifCond().template getDefiningOp<arith::ConstantOp>(); 182 if (constOp && constOp.getValue().template cast<IntegerAttr>().getInt()) 183 rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); }); 184 else if (constOp) 185 rewriter.eraseOp(op); 186 187 return success(); 188 } 189 }; 190 } // namespace 191 192 //===----------------------------------------------------------------------===// 193 // ParallelOp 194 //===----------------------------------------------------------------------===// 195 196 /// Parse acc.parallel operation 197 /// operation := `acc.parallel` `async` `(` index `)`? 198 /// `wait` `(` index-list `)`? 199 /// `num_gangs` `(` value `)`? 200 /// `num_workers` `(` value `)`? 201 /// `vector_length` `(` value `)`? 202 /// `if` `(` value `)`? 203 /// `self` `(` value `)`? 204 /// `reduction` `(` value-list `)`? 205 /// `copy` `(` value-list `)`? 206 /// `copyin` `(` value-list `)`? 207 /// `copyin_readonly` `(` value-list `)`? 208 /// `copyout` `(` value-list `)`? 209 /// `copyout_zero` `(` value-list `)`? 210 /// `create` `(` value-list `)`? 211 /// `create_zero` `(` value-list `)`? 212 /// `no_create` `(` value-list `)`? 213 /// `present` `(` value-list `)`? 214 /// `deviceptr` `(` value-list `)`? 215 /// `attach` `(` value-list `)`? 216 /// `private` `(` value-list `)`? 217 /// `firstprivate` `(` value-list `)`? 218 /// region attr-dict? 219 static ParseResult parseParallelOp(OpAsmParser &parser, 220 OperationState &result) { 221 Builder &builder = parser.getBuilder(); 222 SmallVector<OpAsmParser::OperandType, 8> privateOperands, 223 firstprivateOperands, copyOperands, copyinOperands, 224 copyinReadonlyOperands, copyoutOperands, copyoutZeroOperands, 225 createOperands, createZeroOperands, noCreateOperands, presentOperands, 226 devicePtrOperands, attachOperands, waitOperands, reductionOperands; 227 SmallVector<Type, 8> waitOperandTypes, reductionOperandTypes, 228 copyOperandTypes, copyinOperandTypes, copyinReadonlyOperandTypes, 229 copyoutOperandTypes, copyoutZeroOperandTypes, createOperandTypes, 230 createZeroOperandTypes, noCreateOperandTypes, presentOperandTypes, 231 deviceptrOperandTypes, attachOperandTypes, privateOperandTypes, 232 firstprivateOperandTypes; 233 234 SmallVector<Type, 8> operandTypes; 235 OpAsmParser::OperandType ifCond, selfCond; 236 bool hasIfCond = false, hasSelfCond = false; 237 OptionalParseResult async, numGangs, numWorkers, vectorLength; 238 Type i1Type = builder.getI1Type(); 239 240 // async()? 241 async = parseOptionalOperandAndType(parser, ParallelOp::getAsyncKeyword(), 242 result); 243 if (async.hasValue() && failed(*async)) 244 return failure(); 245 246 // wait()? 247 if (failed(parseOperandList(parser, ParallelOp::getWaitKeyword(), 248 waitOperands, waitOperandTypes, result))) 249 return failure(); 250 251 // num_gangs(value)? 252 numGangs = parseOptionalOperandAndType( 253 parser, ParallelOp::getNumGangsKeyword(), result); 254 if (numGangs.hasValue() && failed(*numGangs)) 255 return failure(); 256 257 // num_workers(value)? 258 numWorkers = parseOptionalOperandAndType( 259 parser, ParallelOp::getNumWorkersKeyword(), result); 260 if (numWorkers.hasValue() && failed(*numWorkers)) 261 return failure(); 262 263 // vector_length(value)? 264 vectorLength = parseOptionalOperandAndType( 265 parser, ParallelOp::getVectorLengthKeyword(), result); 266 if (vectorLength.hasValue() && failed(*vectorLength)) 267 return failure(); 268 269 // if()? 270 if (failed(parseOptionalOperand(parser, ParallelOp::getIfKeyword(), ifCond, 271 i1Type, hasIfCond, result))) 272 return failure(); 273 274 // self()? 275 if (failed(parseOptionalOperand(parser, ParallelOp::getSelfKeyword(), 276 selfCond, i1Type, hasSelfCond, result))) 277 return failure(); 278 279 // reduction()? 280 if (failed(parseOperandList(parser, ParallelOp::getReductionKeyword(), 281 reductionOperands, reductionOperandTypes, 282 result))) 283 return failure(); 284 285 // copy()? 286 if (failed(parseOperandList(parser, ParallelOp::getCopyKeyword(), 287 copyOperands, copyOperandTypes, result))) 288 return failure(); 289 290 // copyin()? 291 if (failed(parseOperandList(parser, ParallelOp::getCopyinKeyword(), 292 copyinOperands, copyinOperandTypes, result))) 293 return failure(); 294 295 // copyin_readonly()? 296 if (failed(parseOperandList(parser, ParallelOp::getCopyinReadonlyKeyword(), 297 copyinReadonlyOperands, 298 copyinReadonlyOperandTypes, result))) 299 return failure(); 300 301 // copyout()? 302 if (failed(parseOperandList(parser, ParallelOp::getCopyoutKeyword(), 303 copyoutOperands, copyoutOperandTypes, result))) 304 return failure(); 305 306 // copyout_zero()? 307 if (failed(parseOperandList(parser, ParallelOp::getCopyoutZeroKeyword(), 308 copyoutZeroOperands, copyoutZeroOperandTypes, 309 result))) 310 return failure(); 311 312 // create()? 313 if (failed(parseOperandList(parser, ParallelOp::getCreateKeyword(), 314 createOperands, createOperandTypes, result))) 315 return failure(); 316 317 // create_zero()? 318 if (failed(parseOperandList(parser, ParallelOp::getCreateZeroKeyword(), 319 createZeroOperands, createZeroOperandTypes, 320 result))) 321 return failure(); 322 323 // no_create()? 324 if (failed(parseOperandList(parser, ParallelOp::getNoCreateKeyword(), 325 noCreateOperands, noCreateOperandTypes, result))) 326 return failure(); 327 328 // present()? 329 if (failed(parseOperandList(parser, ParallelOp::getPresentKeyword(), 330 presentOperands, presentOperandTypes, result))) 331 return failure(); 332 333 // deviceptr()? 334 if (failed(parseOperandList(parser, ParallelOp::getDevicePtrKeyword(), 335 devicePtrOperands, deviceptrOperandTypes, 336 result))) 337 return failure(); 338 339 // attach()? 340 if (failed(parseOperandList(parser, ParallelOp::getAttachKeyword(), 341 attachOperands, attachOperandTypes, result))) 342 return failure(); 343 344 // private()? 345 if (failed(parseOperandList(parser, ParallelOp::getPrivateKeyword(), 346 privateOperands, privateOperandTypes, result))) 347 return failure(); 348 349 // firstprivate()? 350 if (failed(parseOperandList(parser, ParallelOp::getFirstPrivateKeyword(), 351 firstprivateOperands, firstprivateOperandTypes, 352 result))) 353 return failure(); 354 355 // Parallel op region 356 if (failed(parseRegions<ParallelOp>(parser, result))) 357 return failure(); 358 359 result.addAttribute( 360 ParallelOp::getOperandSegmentSizeAttr(), 361 builder.getI32VectorAttr( 362 {static_cast<int32_t>(async.hasValue() ? 1 : 0), 363 static_cast<int32_t>(waitOperands.size()), 364 static_cast<int32_t>(numGangs.hasValue() ? 1 : 0), 365 static_cast<int32_t>(numWorkers.hasValue() ? 1 : 0), 366 static_cast<int32_t>(vectorLength.hasValue() ? 1 : 0), 367 static_cast<int32_t>(hasIfCond ? 1 : 0), 368 static_cast<int32_t>(hasSelfCond ? 1 : 0), 369 static_cast<int32_t>(reductionOperands.size()), 370 static_cast<int32_t>(copyOperands.size()), 371 static_cast<int32_t>(copyinOperands.size()), 372 static_cast<int32_t>(copyinReadonlyOperands.size()), 373 static_cast<int32_t>(copyoutOperands.size()), 374 static_cast<int32_t>(copyoutZeroOperands.size()), 375 static_cast<int32_t>(createOperands.size()), 376 static_cast<int32_t>(createZeroOperands.size()), 377 static_cast<int32_t>(noCreateOperands.size()), 378 static_cast<int32_t>(presentOperands.size()), 379 static_cast<int32_t>(devicePtrOperands.size()), 380 static_cast<int32_t>(attachOperands.size()), 381 static_cast<int32_t>(privateOperands.size()), 382 static_cast<int32_t>(firstprivateOperands.size())})); 383 384 // Additional attributes 385 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 386 return failure(); 387 388 return success(); 389 } 390 391 static void print(OpAsmPrinter &printer, ParallelOp &op) { 392 // async()? 393 if (Value async = op.async()) 394 printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ": " 395 << async.getType() << ")"; 396 397 // wait()? 398 printOperandList(op.waitOperands(), ParallelOp::getWaitKeyword(), printer); 399 400 // num_gangs()? 401 if (Value numGangs = op.numGangs()) 402 printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs 403 << ": " << numGangs.getType() << ")"; 404 405 // num_workers()? 406 if (Value numWorkers = op.numWorkers()) 407 printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers 408 << ": " << numWorkers.getType() << ")"; 409 410 // vector_length()? 411 if (Value vectorLength = op.vectorLength()) 412 printer << " " << ParallelOp::getVectorLengthKeyword() << "(" 413 << vectorLength << ": " << vectorLength.getType() << ")"; 414 415 // if()? 416 if (Value ifCond = op.ifCond()) 417 printer << " " << ParallelOp::getIfKeyword() << "(" << ifCond << ")"; 418 419 // self()? 420 if (Value selfCond = op.selfCond()) 421 printer << " " << ParallelOp::getSelfKeyword() << "(" << selfCond << ")"; 422 423 // reduction()? 424 printOperandList(op.reductionOperands(), ParallelOp::getReductionKeyword(), 425 printer); 426 427 // copy()? 428 printOperandList(op.copyOperands(), ParallelOp::getCopyKeyword(), printer); 429 430 // copyin()? 431 printOperandList(op.copyinOperands(), ParallelOp::getCopyinKeyword(), 432 printer); 433 434 // copyin_readonly()? 435 printOperandList(op.copyinReadonlyOperands(), 436 ParallelOp::getCopyinReadonlyKeyword(), printer); 437 438 // copyout()? 439 printOperandList(op.copyoutOperands(), ParallelOp::getCopyoutKeyword(), 440 printer); 441 442 // copyout_zero()? 443 printOperandList(op.copyoutZeroOperands(), 444 ParallelOp::getCopyoutZeroKeyword(), printer); 445 446 // create()? 447 printOperandList(op.createOperands(), ParallelOp::getCreateKeyword(), 448 printer); 449 450 // create_zero()? 451 printOperandList(op.createZeroOperands(), ParallelOp::getCreateZeroKeyword(), 452 printer); 453 454 // no_create()? 455 printOperandList(op.noCreateOperands(), ParallelOp::getNoCreateKeyword(), 456 printer); 457 458 // present()? 459 printOperandList(op.presentOperands(), ParallelOp::getPresentKeyword(), 460 printer); 461 462 // deviceptr()? 463 printOperandList(op.devicePtrOperands(), ParallelOp::getDevicePtrKeyword(), 464 printer); 465 466 // attach()? 467 printOperandList(op.attachOperands(), ParallelOp::getAttachKeyword(), 468 printer); 469 470 // private()? 471 printOperandList(op.gangPrivateOperands(), ParallelOp::getPrivateKeyword(), 472 printer); 473 474 // firstprivate()? 475 printOperandList(op.gangFirstPrivateOperands(), 476 ParallelOp::getFirstPrivateKeyword(), printer); 477 478 printer << ' '; 479 printer.printRegion(op.region(), 480 /*printEntryBlockArgs=*/false, 481 /*printBlockTerminators=*/true); 482 printer.printOptionalAttrDictWithKeyword( 483 op->getAttrs(), ParallelOp::getOperandSegmentSizeAttr()); 484 } 485 486 unsigned ParallelOp::getNumDataOperands() { 487 return reductionOperands().size() + copyOperands().size() + 488 copyinOperands().size() + copyinReadonlyOperands().size() + 489 copyoutOperands().size() + copyoutZeroOperands().size() + 490 createOperands().size() + createZeroOperands().size() + 491 noCreateOperands().size() + presentOperands().size() + 492 devicePtrOperands().size() + attachOperands().size() + 493 gangPrivateOperands().size() + gangFirstPrivateOperands().size(); 494 } 495 496 Value ParallelOp::getDataOperand(unsigned i) { 497 unsigned numOptional = async() ? 1 : 0; 498 numOptional += numGangs() ? 1 : 0; 499 numOptional += numWorkers() ? 1 : 0; 500 numOptional += vectorLength() ? 1 : 0; 501 numOptional += ifCond() ? 1 : 0; 502 numOptional += selfCond() ? 1 : 0; 503 return getOperand(waitOperands().size() + numOptional + i); 504 } 505 506 //===----------------------------------------------------------------------===// 507 // LoopOp 508 //===----------------------------------------------------------------------===// 509 510 /// Parse acc.loop operation 511 /// operation := `acc.loop` 512 /// (`gang` ( `(` (`num=` value)? (`,` `static=` value `)`)? )? )? 513 /// (`vector` ( `(` value `)` )? )? (`worker` (`(` value `)`)? )? 514 /// (`vector_length` `(` value `)`)? 515 /// (`tile` `(` value-list `)`)? 516 /// (`private` `(` value-list `)`)? 517 /// (`reduction` `(` value-list `)`)? 518 /// region attr-dict? 519 static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) { 520 Builder &builder = parser.getBuilder(); 521 unsigned executionMapping = OpenACCExecMapping::NONE; 522 SmallVector<Type, 8> operandTypes; 523 SmallVector<OpAsmParser::OperandType, 8> privateOperands, reductionOperands; 524 SmallVector<OpAsmParser::OperandType, 8> tileOperands; 525 OptionalParseResult gangNum, gangStatic, worker, vector; 526 527 // gang? 528 if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangKeyword()))) 529 executionMapping |= OpenACCExecMapping::GANG; 530 531 // optional gang operand 532 if (succeeded(parser.parseOptionalLParen())) { 533 gangNum = parserOptionalOperandAndTypeWithPrefix( 534 parser, result, LoopOp::getGangNumKeyword()); 535 if (gangNum.hasValue() && failed(*gangNum)) 536 return failure(); 537 parser.parseOptionalComma(); 538 gangStatic = parserOptionalOperandAndTypeWithPrefix( 539 parser, result, LoopOp::getGangStaticKeyword()); 540 if (gangStatic.hasValue() && failed(*gangStatic)) 541 return failure(); 542 parser.parseOptionalComma(); 543 if (failed(parser.parseRParen())) 544 return failure(); 545 } 546 547 // worker? 548 if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerKeyword()))) 549 executionMapping |= OpenACCExecMapping::WORKER; 550 551 // optional worker operand 552 worker = parseOptionalOperandAndType(parser, result); 553 if (worker.hasValue() && failed(*worker)) 554 return failure(); 555 556 // vector? 557 if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorKeyword()))) 558 executionMapping |= OpenACCExecMapping::VECTOR; 559 560 // optional vector operand 561 vector = parseOptionalOperandAndType(parser, result); 562 if (vector.hasValue() && failed(*vector)) 563 return failure(); 564 565 // tile()? 566 if (failed(parseOperandList(parser, LoopOp::getTileKeyword(), tileOperands, 567 operandTypes, result))) 568 return failure(); 569 570 // private()? 571 if (failed(parseOperandList(parser, LoopOp::getPrivateKeyword(), 572 privateOperands, operandTypes, result))) 573 return failure(); 574 575 // reduction()? 576 if (failed(parseOperandList(parser, LoopOp::getReductionKeyword(), 577 reductionOperands, operandTypes, result))) 578 return failure(); 579 580 if (executionMapping != acc::OpenACCExecMapping::NONE) 581 result.addAttribute(LoopOp::getExecutionMappingAttrName(), 582 builder.getI64IntegerAttr(executionMapping)); 583 584 // Parse optional results in case there is a reduce. 585 if (parser.parseOptionalArrowTypeList(result.types)) 586 return failure(); 587 588 if (failed(parseRegions<LoopOp>(parser, result))) 589 return failure(); 590 591 result.addAttribute(LoopOp::getOperandSegmentSizeAttr(), 592 builder.getI32VectorAttr( 593 {static_cast<int32_t>(gangNum.hasValue() ? 1 : 0), 594 static_cast<int32_t>(gangStatic.hasValue() ? 1 : 0), 595 static_cast<int32_t>(worker.hasValue() ? 1 : 0), 596 static_cast<int32_t>(vector.hasValue() ? 1 : 0), 597 static_cast<int32_t>(tileOperands.size()), 598 static_cast<int32_t>(privateOperands.size()), 599 static_cast<int32_t>(reductionOperands.size())})); 600 601 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) 602 return failure(); 603 604 return success(); 605 } 606 607 static void print(OpAsmPrinter &printer, LoopOp &op) { 608 unsigned execMapping = op.exec_mapping(); 609 if (execMapping & OpenACCExecMapping::GANG) { 610 printer << " " << LoopOp::getGangKeyword(); 611 Value gangNum = op.gangNum(); 612 Value gangStatic = op.gangStatic(); 613 614 // Print optional gang operands 615 if (gangNum || gangStatic) { 616 printer << "("; 617 if (gangNum) { 618 printer << LoopOp::getGangNumKeyword() << "=" << gangNum << ": " 619 << gangNum.getType(); 620 if (gangStatic) 621 printer << ", "; 622 } 623 if (gangStatic) 624 printer << LoopOp::getGangStaticKeyword() << "=" << gangStatic << ": " 625 << gangStatic.getType(); 626 printer << ")"; 627 } 628 } 629 630 if (execMapping & OpenACCExecMapping::WORKER) { 631 printer << " " << LoopOp::getWorkerKeyword(); 632 633 // Print optional worker operand if present 634 if (Value workerNum = op.workerNum()) 635 printer << "(" << workerNum << ": " << workerNum.getType() << ")"; 636 } 637 638 if (execMapping & OpenACCExecMapping::VECTOR) { 639 printer << " " << LoopOp::getVectorKeyword(); 640 641 // Print optional vector operand if present 642 if (Value vectorLength = op.vectorLength()) 643 printer << "(" << vectorLength << ": " << vectorLength.getType() << ")"; 644 } 645 646 // tile()? 647 printOperandList(op.tileOperands(), LoopOp::getTileKeyword(), printer); 648 649 // private()? 650 printOperandList(op.privateOperands(), LoopOp::getPrivateKeyword(), printer); 651 652 // reduction()? 653 printOperandList(op.reductionOperands(), LoopOp::getReductionKeyword(), 654 printer); 655 656 if (op.getNumResults() > 0) 657 printer << " -> (" << op.getResultTypes() << ")"; 658 659 printer << ' '; 660 printer.printRegion(op.region(), 661 /*printEntryBlockArgs=*/false, 662 /*printBlockTerminators=*/true); 663 664 printer.printOptionalAttrDictWithKeyword( 665 op->getAttrs(), {LoopOp::getExecutionMappingAttrName(), 666 LoopOp::getOperandSegmentSizeAttr()}); 667 } 668 669 static LogicalResult verifyLoopOp(acc::LoopOp loopOp) { 670 // auto, independent and seq attribute are mutually exclusive. 671 if ((loopOp.auto_() && (loopOp.independent() || loopOp.seq())) || 672 (loopOp.independent() && loopOp.seq())) { 673 loopOp.emitError("only one of " + acc::LoopOp::getAutoAttrName() + ", " + 674 acc::LoopOp::getIndependentAttrName() + ", " + 675 acc::LoopOp::getSeqAttrName() + 676 " can be present at the same time"); 677 return failure(); 678 } 679 680 // Gang, worker and vector are incompatible with seq. 681 if (loopOp.seq() && loopOp.exec_mapping() != OpenACCExecMapping::NONE) { 682 loopOp.emitError("gang, worker or vector cannot appear with the seq attr"); 683 return failure(); 684 } 685 686 // Check non-empty body(). 687 if (loopOp.region().empty()) { 688 loopOp.emitError("expected non-empty body."); 689 return failure(); 690 } 691 692 return success(); 693 } 694 695 //===----------------------------------------------------------------------===// 696 // DataOp 697 //===----------------------------------------------------------------------===// 698 699 static LogicalResult verify(acc::DataOp dataOp) { 700 // 2.6.5. Data Construct restriction 701 // At least one copy, copyin, copyout, create, no_create, present, deviceptr, 702 // attach, or default clause must appear on a data construct. 703 if (dataOp.getOperands().empty() && !dataOp.defaultAttr()) 704 return dataOp.emitError("at least one operand or the default attribute " 705 "must appear on the data operation"); 706 return success(); 707 } 708 709 unsigned DataOp::getNumDataOperands() { 710 return copyOperands().size() + copyinOperands().size() + 711 copyinReadonlyOperands().size() + copyoutOperands().size() + 712 copyoutZeroOperands().size() + createOperands().size() + 713 createZeroOperands().size() + noCreateOperands().size() + 714 presentOperands().size() + deviceptrOperands().size() + 715 attachOperands().size(); 716 } 717 718 Value DataOp::getDataOperand(unsigned i) { 719 unsigned numOptional = ifCond() ? 1 : 0; 720 return getOperand(numOptional + i); 721 } 722 723 //===----------------------------------------------------------------------===// 724 // ExitDataOp 725 //===----------------------------------------------------------------------===// 726 727 static LogicalResult verify(acc::ExitDataOp op) { 728 // 2.6.6. Data Exit Directive restriction 729 // At least one copyout, delete, or detach clause must appear on an exit data 730 // directive. 731 if (op.copyoutOperands().empty() && op.deleteOperands().empty() && 732 op.detachOperands().empty()) 733 return op.emitError( 734 "at least one operand in copyout, delete or detach must appear on the " 735 "exit data operation"); 736 737 // The async attribute represent the async clause without value. Therefore the 738 // attribute and operand cannot appear at the same time. 739 if (op.asyncOperand() && op.async()) 740 return op.emitError("async attribute cannot appear with asyncOperand"); 741 742 // The wait attribute represent the wait clause without values. Therefore the 743 // attribute and operands cannot appear at the same time. 744 if (!op.waitOperands().empty() && op.wait()) 745 return op.emitError("wait attribute cannot appear with waitOperands"); 746 747 if (op.waitDevnum() && op.waitOperands().empty()) 748 return op.emitError("wait_devnum cannot appear without waitOperands"); 749 750 return success(); 751 } 752 753 unsigned ExitDataOp::getNumDataOperands() { 754 return copyoutOperands().size() + deleteOperands().size() + 755 detachOperands().size(); 756 } 757 758 Value ExitDataOp::getDataOperand(unsigned i) { 759 unsigned numOptional = ifCond() ? 1 : 0; 760 numOptional += asyncOperand() ? 1 : 0; 761 numOptional += waitDevnum() ? 1 : 0; 762 return getOperand(waitOperands().size() + numOptional + i); 763 } 764 765 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results, 766 MLIRContext *context) { 767 results.add<RemoveConstantIfCondition<ExitDataOp>>(context); 768 } 769 770 //===----------------------------------------------------------------------===// 771 // EnterDataOp 772 //===----------------------------------------------------------------------===// 773 774 static LogicalResult verify(acc::EnterDataOp op) { 775 // 2.6.6. Data Enter Directive restriction 776 // At least one copyin, create, or attach clause must appear on an enter data 777 // directive. 778 if (op.copyinOperands().empty() && op.createOperands().empty() && 779 op.createZeroOperands().empty() && op.attachOperands().empty()) 780 return op.emitError( 781 "at least one operand in copyin, create, " 782 "create_zero or attach must appear on the enter data operation"); 783 784 // The async attribute represent the async clause without value. Therefore the 785 // attribute and operand cannot appear at the same time. 786 if (op.asyncOperand() && op.async()) 787 return op.emitError("async attribute cannot appear with asyncOperand"); 788 789 // The wait attribute represent the wait clause without values. Therefore the 790 // attribute and operands cannot appear at the same time. 791 if (!op.waitOperands().empty() && op.wait()) 792 return op.emitError("wait attribute cannot appear with waitOperands"); 793 794 if (op.waitDevnum() && op.waitOperands().empty()) 795 return op.emitError("wait_devnum cannot appear without waitOperands"); 796 797 return success(); 798 } 799 800 unsigned EnterDataOp::getNumDataOperands() { 801 return copyinOperands().size() + createOperands().size() + 802 createZeroOperands().size() + attachOperands().size(); 803 } 804 805 Value EnterDataOp::getDataOperand(unsigned i) { 806 unsigned numOptional = ifCond() ? 1 : 0; 807 numOptional += asyncOperand() ? 1 : 0; 808 numOptional += waitDevnum() ? 1 : 0; 809 return getOperand(waitOperands().size() + numOptional + i); 810 } 811 812 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results, 813 MLIRContext *context) { 814 results.add<RemoveConstantIfCondition<EnterDataOp>>(context); 815 } 816 817 //===----------------------------------------------------------------------===// 818 // InitOp 819 //===----------------------------------------------------------------------===// 820 821 static LogicalResult verify(acc::InitOp initOp) { 822 Operation *currOp = initOp; 823 while ((currOp = currOp->getParentOp())) { 824 if (isComputeOperation(currOp)) 825 return initOp.emitOpError("cannot be nested in a compute operation"); 826 } 827 return success(); 828 } 829 830 //===----------------------------------------------------------------------===// 831 // ShutdownOp 832 //===----------------------------------------------------------------------===// 833 834 static LogicalResult verify(acc::ShutdownOp op) { 835 Operation *currOp = op; 836 while ((currOp = currOp->getParentOp())) { 837 if (isComputeOperation(currOp)) 838 return op.emitOpError("cannot be nested in a compute operation"); 839 } 840 return success(); 841 } 842 843 //===----------------------------------------------------------------------===// 844 // UpdateOp 845 //===----------------------------------------------------------------------===// 846 847 static LogicalResult verify(acc::UpdateOp updateOp) { 848 // At least one of host or device should have a value. 849 if (updateOp.hostOperands().empty() && updateOp.deviceOperands().empty()) 850 return updateOp.emitError("at least one value must be present in" 851 " hostOperands or deviceOperands"); 852 853 // The async attribute represent the async clause without value. Therefore the 854 // attribute and operand cannot appear at the same time. 855 if (updateOp.asyncOperand() && updateOp.async()) 856 return updateOp.emitError("async attribute cannot appear with " 857 " asyncOperand"); 858 859 // The wait attribute represent the wait clause without values. Therefore the 860 // attribute and operands cannot appear at the same time. 861 if (!updateOp.waitOperands().empty() && updateOp.wait()) 862 return updateOp.emitError("wait attribute cannot appear with waitOperands"); 863 864 if (updateOp.waitDevnum() && updateOp.waitOperands().empty()) 865 return updateOp.emitError("wait_devnum cannot appear without waitOperands"); 866 867 return success(); 868 } 869 870 unsigned UpdateOp::getNumDataOperands() { 871 return hostOperands().size() + deviceOperands().size(); 872 } 873 874 Value UpdateOp::getDataOperand(unsigned i) { 875 unsigned numOptional = asyncOperand() ? 1 : 0; 876 numOptional += waitDevnum() ? 1 : 0; 877 numOptional += ifCond() ? 1 : 0; 878 return getOperand(waitOperands().size() + deviceTypeOperands().size() + 879 numOptional + i); 880 } 881 882 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results, 883 MLIRContext *context) { 884 results.add<RemoveConstantIfCondition<UpdateOp>>(context); 885 } 886 887 //===----------------------------------------------------------------------===// 888 // WaitOp 889 //===----------------------------------------------------------------------===// 890 891 static LogicalResult verify(acc::WaitOp waitOp) { 892 // The async attribute represent the async clause without value. Therefore the 893 // attribute and operand cannot appear at the same time. 894 if (waitOp.asyncOperand() && waitOp.async()) 895 return waitOp.emitError("async attribute cannot appear with asyncOperand"); 896 897 if (waitOp.waitDevnum() && waitOp.waitOperands().empty()) 898 return waitOp.emitError("wait_devnum cannot appear without waitOperands"); 899 900 return success(); 901 } 902 903 #define GET_OP_CLASSES 904 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc" 905 906 #define GET_ATTRDEF_CLASSES 907 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc" 908