//===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the OpenMP dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"

#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
#include <cstddef>

#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"

using namespace mlir;
using namespace mlir::omp;

namespace {
/// Model for pointer-like types that already provide a `getElementType` method.
///
/// This is an external model of the `PointerLikeType` interface. Its
/// `getElementType` casts the incoming type to `T` and forwards to
/// `T::getElementType()`. The cast is unchecked, so the model must only be
/// attached to `T` itself.
template <typename T>
struct PointerLikeModel
    : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
  Type getElementType(Type pointer) const {
    return pointer.cast<T>().getElementType();
  }
};
} // namespace

/// Registers the OpenMP dialect's operations and attributes and attaches the
/// `PointerLikeType` external model to `LLVM::LLVMPointerType` and
/// `MemRefType`.
void OpenMPDialect::initialize() {
  // The operation and attribute lists are produced by the generated `.inc`
  // files: each `GET_*_LIST` define makes the include expand to a
  // comma-separated type list inside the template argument brackets.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
      >();

  // Make LLVM pointers and memrefs usable wherever a `PointerLikeType` is
  // expected.
  LLVM::LLVMPointerType::attachInterface<
      PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
  MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
}

//===----------------------------------------------------------------------===//
// ParallelOp
//===----------------------------------------------------------------------===//

/// Builds a `ParallelOp` with every clause left unset and then appends the
/// caller-supplied `attributes` to `state`. "Unset" means null values for the
/// scalar operands and attributes, and empty ranges for the variadic operand
/// lists.
void ParallelOp::build(OpBuilder &builder, OperationState &state,
                       ArrayRef<NamedAttribute> attributes) {
  ParallelOp::build(
      builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
      /*default_val=*/nullptr, /*private_vars=*/ValueRange(),
      /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(),
      /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(),
      /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr);
  state.addAttributes(attributes);
}

//===----------------------------------------------------------------------===//
// Parser and printer for Operand and type list
//===----------------------------------------------------------------------===//

/// Parse a list of operands with types.
81 /// 82 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)` 83 /// ssa-id-and-type-list ::= ssa-id-and-type | 84 /// ssa-id-and-type `,` ssa-id-and-type-list 85 /// ssa-id-and-type ::= ssa-id `:` type 86 static ParseResult 87 parseOperandAndTypeList(OpAsmParser &parser, 88 SmallVectorImpl<OpAsmParser::OperandType> &operands, 89 SmallVectorImpl<Type> &types) { 90 return parser.parseCommaSeparatedList( 91 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 92 OpAsmParser::OperandType operand; 93 Type type; 94 if (parser.parseOperand(operand) || parser.parseColonType(type)) 95 return failure(); 96 operands.push_back(operand); 97 types.push_back(type); 98 return success(); 99 }); 100 } 101 102 /// Print an operand and type list with parentheses 103 static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) { 104 p << "("; 105 llvm::interleaveComma( 106 operands, p, [&](const Value &v) { p << v << " : " << v.getType(); }); 107 p << ") "; 108 } 109 110 /// Print data variables corresponding to a data-sharing clause `name` 111 static void printDataVars(OpAsmPrinter &p, OperandRange operands, 112 StringRef name) { 113 if (!operands.empty()) { 114 p << name; 115 printOperandAndTypeList(p, operands); 116 } 117 } 118 119 //===----------------------------------------------------------------------===// 120 // Parser and printer for Allocate Clause 121 //===----------------------------------------------------------------------===// 122 123 /// Parse an allocate clause with allocators and a list of operands with types. 
124 /// 125 /// allocate ::= `allocate` `(` allocate-operand-list `)` 126 /// allocate-operand-list :: = allocate-operand | 127 /// allocator-operand `,` allocate-operand-list 128 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 129 /// ssa-id-and-type ::= ssa-id `:` type 130 static ParseResult parseAllocateAndAllocator( 131 OpAsmParser &parser, 132 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate, 133 SmallVectorImpl<Type> &typesAllocate, 134 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator, 135 SmallVectorImpl<Type> &typesAllocator) { 136 137 return parser.parseCommaSeparatedList( 138 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 139 OpAsmParser::OperandType operand; 140 Type type; 141 if (parser.parseOperand(operand) || parser.parseColonType(type)) 142 return failure(); 143 operandsAllocator.push_back(operand); 144 typesAllocator.push_back(type); 145 if (parser.parseArrow()) 146 return failure(); 147 if (parser.parseOperand(operand) || parser.parseColonType(type)) 148 return failure(); 149 150 operandsAllocate.push_back(operand); 151 typesAllocate.push_back(type); 152 return success(); 153 }); 154 } 155 156 /// Print allocate clause 157 static void printAllocateAndAllocator(OpAsmPrinter &p, 158 OperandRange varsAllocate, 159 OperandRange varsAllocator) { 160 p << "allocate("; 161 for (unsigned i = 0; i < varsAllocate.size(); ++i) { 162 std::string separator = i == varsAllocate.size() - 1 ? 
") " : ", "; 163 p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> "; 164 p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator; 165 } 166 } 167 168 static LogicalResult verifyParallelOp(ParallelOp op) { 169 if (op.allocate_vars().size() != op.allocators_vars().size()) 170 return op.emitError( 171 "expected equal sizes for allocate and allocator variables"); 172 return success(); 173 } 174 175 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) { 176 p << " "; 177 if (auto ifCond = op.if_expr_var()) 178 p << "if(" << ifCond << " : " << ifCond.getType() << ") "; 179 180 if (auto threads = op.num_threads_var()) 181 p << "num_threads(" << threads << " : " << threads.getType() << ") "; 182 183 printDataVars(p, op.private_vars(), "private"); 184 printDataVars(p, op.firstprivate_vars(), "firstprivate"); 185 printDataVars(p, op.shared_vars(), "shared"); 186 printDataVars(p, op.copyin_vars(), "copyin"); 187 188 if (!op.allocate_vars().empty()) 189 printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars()); 190 191 if (auto def = op.default_val()) 192 p << "default(" << stringifyClauseDefault(*def).drop_front(3) << ") "; 193 194 if (auto bind = op.proc_bind_val()) 195 p << "proc_bind(" << stringifyClauseProcBindKind(*bind) << ") "; 196 197 p << ' '; 198 p.printRegion(op.getRegion()); 199 } 200 201 static void printTargetOp(OpAsmPrinter &p, TargetOp op) { 202 p << " "; 203 if (auto ifCond = op.if_expr()) 204 p << "if(" << ifCond << " : " << ifCond.getType() << ") "; 205 206 if (auto device = op.device()) 207 p << "device(" << device << " : " << device.getType() << ") "; 208 209 if (auto threads = op.thread_limit()) 210 p << "thread_limit(" << threads << " : " << threads.getType() << ") "; 211 212 if (op.nowait()) { 213 p << "nowait "; 214 } 215 216 p.printRegion(op.getRegion()); 217 } 218 219 //===----------------------------------------------------------------------===// 220 // Parser and printer for Linear Clause 221 
//===----------------------------------------------------------------------===// 222 223 /// linear ::= `linear` `(` linear-list `)` 224 /// linear-list := linear-val | linear-val linear-list 225 /// linear-val := ssa-id-and-type `=` ssa-id-and-type 226 static ParseResult 227 parseLinearClause(OpAsmParser &parser, 228 SmallVectorImpl<OpAsmParser::OperandType> &vars, 229 SmallVectorImpl<Type> &types, 230 SmallVectorImpl<OpAsmParser::OperandType> &stepVars) { 231 if (parser.parseLParen()) 232 return failure(); 233 234 do { 235 OpAsmParser::OperandType var; 236 Type type; 237 OpAsmParser::OperandType stepVar; 238 if (parser.parseOperand(var) || parser.parseEqual() || 239 parser.parseOperand(stepVar) || parser.parseColonType(type)) 240 return failure(); 241 242 vars.push_back(var); 243 types.push_back(type); 244 stepVars.push_back(stepVar); 245 } while (succeeded(parser.parseOptionalComma())); 246 247 if (parser.parseRParen()) 248 return failure(); 249 250 return success(); 251 } 252 253 /// Print Linear Clause 254 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars, 255 OperandRange linearStepVars) { 256 size_t linearVarsSize = linearVars.size(); 257 p << "linear("; 258 for (unsigned i = 0; i < linearVarsSize; ++i) { 259 std::string separator = i == linearVarsSize - 1 ? 
") " : ", "; 260 p << linearVars[i]; 261 if (linearStepVars.size() > i) 262 p << " = " << linearStepVars[i]; 263 p << " : " << linearVars[i].getType() << separator; 264 } 265 } 266 267 //===----------------------------------------------------------------------===// 268 // Parser and printer for Schedule Clause 269 //===----------------------------------------------------------------------===// 270 271 static ParseResult 272 verifyScheduleModifiers(OpAsmParser &parser, 273 SmallVectorImpl<SmallString<12>> &modifiers) { 274 if (modifiers.size() > 2) 275 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)"; 276 for (const auto &mod : modifiers) { 277 // Translate the string. If it has no value, then it was not a valid 278 // modifier! 279 auto symbol = symbolizeScheduleModifier(mod); 280 if (!symbol.hasValue()) 281 return parser.emitError(parser.getNameLoc()) 282 << " unknown modifier type: " << mod; 283 } 284 285 // If we have one modifier that is "simd", then stick a "none" modiifer in 286 // index 0. 287 if (modifiers.size() == 1) { 288 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) { 289 modifiers.push_back(modifiers[0]); 290 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none); 291 } 292 } else if (modifiers.size() == 2) { 293 // If there are two modifier: 294 // First modifier should not be simd, second one should be simd 295 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd || 296 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd) 297 return parser.emitError(parser.getNameLoc()) 298 << " incorrect modifier order"; 299 } 300 return success(); 301 } 302 303 /// schedule ::= `schedule` `(` sched-list `)` 304 /// sched-list ::= sched-val | sched-val sched-list | 305 /// sched-val `,` sched-modifier 306 /// sched-val ::= sched-with-chunk | sched-wo-chunk 307 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 
308 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided` 309 /// sched-wo-chunk ::= `auto` | `runtime` 310 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val 311 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none` 312 static ParseResult 313 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, 314 SmallVectorImpl<SmallString<12>> &modifiers, 315 Optional<OpAsmParser::OperandType> &chunkSize, 316 Type &chunkType) { 317 if (parser.parseLParen()) 318 return failure(); 319 320 StringRef keyword; 321 if (parser.parseKeyword(&keyword)) 322 return failure(); 323 324 schedule = keyword; 325 if (keyword == "static" || keyword == "dynamic" || keyword == "guided") { 326 if (succeeded(parser.parseOptionalEqual())) { 327 chunkSize = OpAsmParser::OperandType{}; 328 if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType)) 329 return failure(); 330 } else { 331 chunkSize = llvm::NoneType::None; 332 } 333 } else if (keyword == "auto" || keyword == "runtime") { 334 chunkSize = llvm::NoneType::None; 335 } else { 336 return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; 337 } 338 339 // If there is a comma, we have one or more modifiers.. 
340 while (succeeded(parser.parseOptionalComma())) { 341 StringRef mod; 342 if (parser.parseKeyword(&mod)) 343 return failure(); 344 modifiers.push_back(mod); 345 } 346 347 if (parser.parseRParen()) 348 return failure(); 349 350 if (verifyScheduleModifiers(parser, modifiers)) 351 return failure(); 352 353 return success(); 354 } 355 356 /// Print schedule clause 357 static void printScheduleClause(OpAsmPrinter &p, ClauseScheduleKind sched, 358 Optional<ScheduleModifier> modifier, bool simd, 359 Value scheduleChunkVar) { 360 p << "schedule(" << stringifyClauseScheduleKind(sched).lower(); 361 if (scheduleChunkVar) 362 p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType(); 363 if (modifier) 364 p << ", " << stringifyScheduleModifier(*modifier); 365 if (simd) 366 p << ", simd"; 367 p << ") "; 368 } 369 370 //===----------------------------------------------------------------------===// 371 // Parser, printer and verifier for ReductionVarList 372 //===----------------------------------------------------------------------===// 373 374 /// reduction ::= `reduction` `(` reduction-entry-list `)` 375 /// reduction-entry-list ::= reduction-entry 376 /// | reduction-entry-list `,` reduction-entry 377 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type 378 static ParseResult 379 parseReductionVarList(OpAsmParser &parser, 380 SmallVectorImpl<SymbolRefAttr> &symbols, 381 SmallVectorImpl<OpAsmParser::OperandType> &operands, 382 SmallVectorImpl<Type> &types) { 383 if (failed(parser.parseLParen())) 384 return failure(); 385 386 do { 387 if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() || 388 parser.parseOperand(operands.emplace_back()) || 389 parser.parseColonType(types.emplace_back())) 390 return failure(); 391 } while (succeeded(parser.parseOptionalComma())); 392 return parser.parseRParen(); 393 } 394 395 /// Print Reduction clause 396 static void printReductionVarList(OpAsmPrinter &p, 397 Optional<ArrayAttr> reductions, 398 OperandRange 
reductionVars) { 399 p << "reduction("; 400 for (unsigned i = 0, e = reductions->size(); i < e; ++i) { 401 if (i != 0) 402 p << ", "; 403 p << (*reductions)[i] << " -> " << reductionVars[i] << " : " 404 << reductionVars[i].getType(); 405 } 406 p << ") "; 407 } 408 409 /// Verifies Reduction Clause 410 static LogicalResult verifyReductionVarList(Operation *op, 411 Optional<ArrayAttr> reductions, 412 OperandRange reductionVars) { 413 if (!reductionVars.empty()) { 414 if (!reductions || reductions->size() != reductionVars.size()) 415 return op->emitOpError() 416 << "expected as many reduction symbol references " 417 "as reduction variables"; 418 } else { 419 if (reductions) 420 return op->emitOpError() << "unexpected reduction symbol references"; 421 return success(); 422 } 423 424 DenseSet<Value> accumulators; 425 for (auto args : llvm::zip(reductionVars, *reductions)) { 426 Value accum = std::get<0>(args); 427 428 if (!accumulators.insert(accum).second) 429 return op->emitOpError() << "accumulator variable used more than once"; 430 431 Type varType = accum.getType().cast<PointerLikeType>(); 432 auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>(); 433 auto decl = 434 SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef); 435 if (!decl) 436 return op->emitOpError() << "expected symbol reference " << symbolRef 437 << " to point to a reduction declaration"; 438 439 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) 440 return op->emitOpError() 441 << "expected accumulator (" << varType 442 << ") to be the same type as reduction declaration (" 443 << decl.getAccumulatorType() << ")"; 444 } 445 446 return success(); 447 } 448 449 //===----------------------------------------------------------------------===// 450 // Parser, printer and verifier for Synchronization Hint (2.17.12) 451 //===----------------------------------------------------------------------===// 452 453 /// Parses a Synchronization Hint clause. 
The value of hint is an integer 454 /// which is a combination of different hints from `omp_sync_hint_t`. 455 /// 456 /// hint-clause = `hint` `(` hint-value `)` 457 static ParseResult parseSynchronizationHint(OpAsmParser &parser, 458 IntegerAttr &hintAttr, 459 bool parseKeyword = true) { 460 if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) { 461 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0); 462 return success(); 463 } 464 465 if (failed(parser.parseLParen())) 466 return failure(); 467 StringRef hintKeyword; 468 int64_t hint = 0; 469 do { 470 if (failed(parser.parseKeyword(&hintKeyword))) 471 return failure(); 472 if (hintKeyword == "uncontended") 473 hint |= 1; 474 else if (hintKeyword == "contended") 475 hint |= 2; 476 else if (hintKeyword == "nonspeculative") 477 hint |= 4; 478 else if (hintKeyword == "speculative") 479 hint |= 8; 480 else 481 return parser.emitError(parser.getCurrentLocation()) 482 << hintKeyword << " is not a valid hint"; 483 } while (succeeded(parser.parseOptionalComma())); 484 if (failed(parser.parseRParen())) 485 return failure(); 486 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint); 487 return success(); 488 } 489 490 /// Prints a Synchronization Hint clause 491 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, 492 IntegerAttr hintAttr) { 493 int64_t hint = hintAttr.getInt(); 494 495 if (hint == 0) 496 return; 497 498 // Helper function to get n-th bit from the right end of `value` 499 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 500 501 bool uncontended = bitn(hint, 0); 502 bool contended = bitn(hint, 1); 503 bool nonspeculative = bitn(hint, 2); 504 bool speculative = bitn(hint, 3); 505 506 SmallVector<StringRef> hints; 507 if (uncontended) 508 hints.push_back("uncontended"); 509 if (contended) 510 hints.push_back("contended"); 511 if (nonspeculative) 512 hints.push_back("nonspeculative"); 513 if (speculative) 514 
hints.push_back("speculative"); 515 516 p << "hint("; 517 llvm::interleaveComma(hints, p); 518 p << ") "; 519 } 520 521 /// Verifies a synchronization hint clause 522 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) { 523 524 // Helper function to get n-th bit from the right end of `value` 525 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 526 527 bool uncontended = bitn(hint, 0); 528 bool contended = bitn(hint, 1); 529 bool nonspeculative = bitn(hint, 2); 530 bool speculative = bitn(hint, 3); 531 532 if (uncontended && contended) 533 return op->emitOpError() << "the hints omp_sync_hint_uncontended and " 534 "omp_sync_hint_contended cannot be combined"; 535 if (nonspeculative && speculative) 536 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and " 537 "omp_sync_hint_speculative cannot be combined."; 538 return success(); 539 } 540 541 enum ClauseType { 542 ifClause, 543 numThreadsClause, 544 deviceClause, 545 threadLimitClause, 546 privateClause, 547 firstprivateClause, 548 lastprivateClause, 549 sharedClause, 550 copyinClause, 551 allocateClause, 552 defaultClause, 553 procBindClause, 554 reductionClause, 555 nowaitClause, 556 linearClause, 557 scheduleClause, 558 collapseClause, 559 orderClause, 560 orderedClause, 561 memoryOrderClause, 562 hintClause, 563 COUNT 564 }; 565 566 //===----------------------------------------------------------------------===// 567 // Parser for Clause List 568 //===----------------------------------------------------------------------===// 569 570 /// Parse a clause attribute `(` $value `)`. 
571 template <typename ClauseAttr> 572 static ParseResult parseClauseAttr(AsmParser &parser, OperationState &state, 573 StringRef attrName, StringRef name) { 574 using ClauseT = decltype(std::declval<ClauseAttr>().getValue()); 575 StringRef enumStr; 576 SMLoc loc = parser.getCurrentLocation(); 577 if (parser.parseLParen() || parser.parseKeyword(&enumStr) || 578 parser.parseRParen()) 579 return failure(); 580 if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) { 581 auto attr = ClauseAttr::get(parser.getContext(), *enumValue); 582 state.addAttribute(attrName, attr); 583 return success(); 584 } 585 return parser.emitError(loc, "invalid ") << name << " kind"; 586 } 587 588 /// Parse a list of clauses. The clauses can appear in any order, but their 589 /// operand segment indices are in the same order that they are passed in the 590 /// `clauses` list. The operand segments are added over the prevSegments 591 592 /// clause-list ::= clause clause-list | empty 593 /// clause ::= if | num-threads | private | firstprivate | lastprivate | 594 /// shared | copyin | allocate | default | proc-bind | reduction | 595 /// nowait | linear | schedule | collapse | order | ordered | 596 /// inclusive 597 /// if ::= `if` `(` ssa-id-and-type `)` 598 /// num-threads ::= `num_threads` `(` ssa-id-and-type `)` 599 /// private ::= `private` operand-and-type-list 600 /// firstprivate ::= `firstprivate` operand-and-type-list 601 /// lastprivate ::= `lastprivate` operand-and-type-list 602 /// shared ::= `shared` operand-and-type-list 603 /// copyin ::= `copyin` operand-and-type-list 604 /// allocate ::= `allocate` `(` allocate-operand-list `)` 605 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) 606 /// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` 607 /// reduction ::= `reduction` `(` reduction-entry-list `)` 608 /// nowait ::= `nowait` 609 /// linear ::= `linear` `(` linear-list `)` 610 /// schedule ::= `schedule` `(` 
sched-list `)` 611 /// collapse ::= `collapse` `(` ssa-id-and-type `)` 612 /// order ::= `order` `(` `concurrent` `)` 613 /// ordered ::= `ordered` `(` ssa-id-and-type `)` 614 /// inclusive ::= `inclusive` 615 /// 616 /// Note that each clause can only appear once in the clase-list. 617 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result, 618 SmallVectorImpl<ClauseType> &clauses, 619 SmallVectorImpl<int> &segments) { 620 621 // Check done[clause] to see if it has been parsed already 622 BitVector done(ClauseType::COUNT, false); 623 624 // See pos[clause] to get position of clause in operand segments 625 SmallVector<int> pos(ClauseType::COUNT, -1); 626 627 // Stores the last parsed clause keyword 628 StringRef clauseKeyword; 629 StringRef opName = result.name.getStringRef(); 630 631 // Containers for storing operands, types and attributes for various clauses 632 std::pair<OpAsmParser::OperandType, Type> ifCond; 633 std::pair<OpAsmParser::OperandType, Type> numThreads; 634 std::pair<OpAsmParser::OperandType, Type> device; 635 std::pair<OpAsmParser::OperandType, Type> threadLimit; 636 637 SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates, 638 shareds, copyins; 639 SmallVector<Type> privateTypes, firstprivateTypes, lastprivateTypes, 640 sharedTypes, copyinTypes; 641 642 SmallVector<OpAsmParser::OperandType> allocates, allocators; 643 SmallVector<Type> allocateTypes, allocatorTypes; 644 645 SmallVector<SymbolRefAttr> reductionSymbols; 646 SmallVector<OpAsmParser::OperandType> reductionVars; 647 SmallVector<Type> reductionVarTypes; 648 649 SmallVector<OpAsmParser::OperandType> linears; 650 SmallVector<Type> linearTypes; 651 SmallVector<OpAsmParser::OperandType> linearSteps; 652 653 SmallString<8> schedule; 654 SmallVector<SmallString<12>> modifiers; 655 Optional<OpAsmParser::OperandType> scheduleChunkSize; 656 Type scheduleChunkType; 657 658 // Compute the position of clauses in operand segments 659 int currPos = 0; 660 
for (ClauseType clause : clauses) { 661 662 // Skip the following clauses - they do not take any position in operand 663 // segments 664 if (clause == defaultClause || clause == procBindClause || 665 clause == nowaitClause || clause == collapseClause || 666 clause == orderClause || clause == orderedClause) 667 continue; 668 669 pos[clause] = currPos++; 670 671 // For the following clauses, two positions are reserved in the operand 672 // segments 673 if (clause == allocateClause || clause == linearClause) 674 currPos++; 675 } 676 677 SmallVector<int> clauseSegments(currPos); 678 679 // Helper function to check if a clause is allowed/repeated or not 680 auto checkAllowed = [&](ClauseType clause) -> ParseResult { 681 if (!llvm::is_contained(clauses, clause)) 682 return parser.emitError(parser.getCurrentLocation()) 683 << clauseKeyword << " is not a valid clause for the " << opName 684 << " operation"; 685 if (done[clause]) 686 return parser.emitError(parser.getCurrentLocation()) 687 << "at most one " << clauseKeyword << " clause can appear on the " 688 << opName << " operation"; 689 done[clause] = true; 690 return success(); 691 }; 692 693 while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) { 694 if (clauseKeyword == "if") { 695 if (checkAllowed(ifClause) || parser.parseLParen() || 696 parser.parseOperand(ifCond.first) || 697 parser.parseColonType(ifCond.second) || parser.parseRParen()) 698 return failure(); 699 clauseSegments[pos[ifClause]] = 1; 700 } else if (clauseKeyword == "num_threads") { 701 if (checkAllowed(numThreadsClause) || parser.parseLParen() || 702 parser.parseOperand(numThreads.first) || 703 parser.parseColonType(numThreads.second) || parser.parseRParen()) 704 return failure(); 705 clauseSegments[pos[numThreadsClause]] = 1; 706 } else if (clauseKeyword == "device") { 707 if (checkAllowed(deviceClause) || parser.parseLParen() || 708 parser.parseOperand(device.first) || 709 parser.parseColonType(device.second) || parser.parseRParen()) 710 
return failure(); 711 clauseSegments[pos[deviceClause]] = 1; 712 } else if (clauseKeyword == "thread_limit") { 713 if (checkAllowed(threadLimitClause) || parser.parseLParen() || 714 parser.parseOperand(threadLimit.first) || 715 parser.parseColonType(threadLimit.second) || parser.parseRParen()) 716 return failure(); 717 clauseSegments[pos[threadLimitClause]] = 1; 718 } else if (clauseKeyword == "private") { 719 if (checkAllowed(privateClause) || 720 parseOperandAndTypeList(parser, privates, privateTypes)) 721 return failure(); 722 clauseSegments[pos[privateClause]] = privates.size(); 723 } else if (clauseKeyword == "firstprivate") { 724 if (checkAllowed(firstprivateClause) || 725 parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) 726 return failure(); 727 clauseSegments[pos[firstprivateClause]] = firstprivates.size(); 728 } else if (clauseKeyword == "lastprivate") { 729 if (checkAllowed(lastprivateClause) || 730 parseOperandAndTypeList(parser, lastprivates, lastprivateTypes)) 731 return failure(); 732 clauseSegments[pos[lastprivateClause]] = lastprivates.size(); 733 } else if (clauseKeyword == "shared") { 734 if (checkAllowed(sharedClause) || 735 parseOperandAndTypeList(parser, shareds, sharedTypes)) 736 return failure(); 737 clauseSegments[pos[sharedClause]] = shareds.size(); 738 } else if (clauseKeyword == "copyin") { 739 if (checkAllowed(copyinClause) || 740 parseOperandAndTypeList(parser, copyins, copyinTypes)) 741 return failure(); 742 clauseSegments[pos[copyinClause]] = copyins.size(); 743 } else if (clauseKeyword == "allocate") { 744 if (checkAllowed(allocateClause) || 745 parseAllocateAndAllocator(parser, allocates, allocateTypes, 746 allocators, allocatorTypes)) 747 return failure(); 748 clauseSegments[pos[allocateClause]] = allocates.size(); 749 clauseSegments[pos[allocateClause] + 1] = allocators.size(); 750 } else if (clauseKeyword == "default") { 751 StringRef defval; 752 SMLoc loc = parser.getCurrentLocation(); 753 if 
(checkAllowed(defaultClause) || parser.parseLParen() || 754 parser.parseKeyword(&defval) || parser.parseRParen()) 755 return failure(); 756 // The def prefix is required for the attribute as "private" is a keyword 757 // in C++. 758 if (Optional<ClauseDefault> def = 759 symbolizeClauseDefault(("def" + defval).str())) { 760 result.addAttribute("default_val", 761 ClauseDefaultAttr::get(parser.getContext(), *def)); 762 } else { 763 return parser.emitError(loc, "invalid default clause"); 764 } 765 } else if (clauseKeyword == "proc_bind") { 766 if (checkAllowed(procBindClause) || 767 parseClauseAttr<ClauseProcBindKindAttr>(parser, result, 768 "proc_bind_val", "proc bind")) 769 return failure(); 770 } else if (clauseKeyword == "reduction") { 771 if (checkAllowed(reductionClause) || 772 parseReductionVarList(parser, reductionSymbols, reductionVars, 773 reductionVarTypes)) 774 return failure(); 775 clauseSegments[pos[reductionClause]] = reductionVars.size(); 776 } else if (clauseKeyword == "nowait") { 777 if (checkAllowed(nowaitClause)) 778 return failure(); 779 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 780 result.addAttribute("nowait", attr); 781 } else if (clauseKeyword == "linear") { 782 if (checkAllowed(linearClause) || 783 parseLinearClause(parser, linears, linearTypes, linearSteps)) 784 return failure(); 785 clauseSegments[pos[linearClause]] = linears.size(); 786 clauseSegments[pos[linearClause] + 1] = linearSteps.size(); 787 } else if (clauseKeyword == "schedule") { 788 if (checkAllowed(scheduleClause) || 789 parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize, 790 scheduleChunkType)) 791 return failure(); 792 if (scheduleChunkSize) { 793 clauseSegments[pos[scheduleClause]] = 1; 794 } 795 } else if (clauseKeyword == "collapse") { 796 auto type = parser.getBuilder().getI64Type(); 797 mlir::IntegerAttr attr; 798 if (checkAllowed(collapseClause) || parser.parseLParen() || 799 parser.parseAttribute(attr, type) || parser.parseRParen()) 
800 return failure(); 801 result.addAttribute("collapse_val", attr); 802 } else if (clauseKeyword == "ordered") { 803 mlir::IntegerAttr attr; 804 if (checkAllowed(orderedClause)) 805 return failure(); 806 if (succeeded(parser.parseOptionalLParen())) { 807 auto type = parser.getBuilder().getI64Type(); 808 if (parser.parseAttribute(attr, type) || parser.parseRParen()) 809 return failure(); 810 } else { 811 // Use 0 to represent no ordered parameter was specified 812 attr = parser.getBuilder().getI64IntegerAttr(0); 813 } 814 result.addAttribute("ordered_val", attr); 815 } else if (clauseKeyword == "order") { 816 if (checkAllowed(orderClause) || 817 parseClauseAttr<ClauseOrderKindAttr>(parser, result, "order_val", 818 "order")) 819 return failure(); 820 } else if (clauseKeyword == "memory_order") { 821 if (checkAllowed(memoryOrderClause) || 822 parseClauseAttr<ClauseMemoryOrderKindAttr>( 823 parser, result, "memory_order", "memory order")) 824 return failure(); 825 } else if (clauseKeyword == "hint") { 826 IntegerAttr hint; 827 if (checkAllowed(hintClause) || 828 parseSynchronizationHint(parser, hint, false)) 829 return failure(); 830 result.addAttribute("hint", hint); 831 } else { 832 return parser.emitError(parser.getNameLoc()) 833 << clauseKeyword << " is not a valid clause"; 834 } 835 } 836 837 // Add if parameter. 838 if (done[ifClause] && clauseSegments[pos[ifClause]] && 839 failed( 840 parser.resolveOperand(ifCond.first, ifCond.second, result.operands))) 841 return failure(); 842 843 // Add num_threads parameter. 844 if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] && 845 failed(parser.resolveOperand(numThreads.first, numThreads.second, 846 result.operands))) 847 return failure(); 848 849 // Add device parameter. 850 if (done[deviceClause] && clauseSegments[pos[deviceClause]] && 851 failed( 852 parser.resolveOperand(device.first, device.second, result.operands))) 853 return failure(); 854 855 // Add thread_limit parameter. 
856 if (done[threadLimitClause] && clauseSegments[pos[threadLimitClause]] && 857 failed(parser.resolveOperand(threadLimit.first, threadLimit.second, 858 result.operands))) 859 return failure(); 860 861 // Add private parameters. 862 if (done[privateClause] && clauseSegments[pos[privateClause]] && 863 failed(parser.resolveOperands(privates, privateTypes, 864 privates[0].location, result.operands))) 865 return failure(); 866 867 // Add firstprivate parameters. 868 if (done[firstprivateClause] && clauseSegments[pos[firstprivateClause]] && 869 failed(parser.resolveOperands(firstprivates, firstprivateTypes, 870 firstprivates[0].location, 871 result.operands))) 872 return failure(); 873 874 // Add lastprivate parameters. 875 if (done[lastprivateClause] && clauseSegments[pos[lastprivateClause]] && 876 failed(parser.resolveOperands(lastprivates, lastprivateTypes, 877 lastprivates[0].location, result.operands))) 878 return failure(); 879 880 // Add shared parameters. 881 if (done[sharedClause] && clauseSegments[pos[sharedClause]] && 882 failed(parser.resolveOperands(shareds, sharedTypes, shareds[0].location, 883 result.operands))) 884 return failure(); 885 886 // Add copyin parameters. 887 if (done[copyinClause] && clauseSegments[pos[copyinClause]] && 888 failed(parser.resolveOperands(copyins, copyinTypes, copyins[0].location, 889 result.operands))) 890 return failure(); 891 892 // Add allocate parameters. 893 if (done[allocateClause] && clauseSegments[pos[allocateClause]] && 894 failed(parser.resolveOperands(allocates, allocateTypes, 895 allocates[0].location, result.operands))) 896 return failure(); 897 898 // Add allocator parameters. 
899 if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] && 900 failed(parser.resolveOperands(allocators, allocatorTypes, 901 allocators[0].location, result.operands))) 902 return failure(); 903 904 // Add reduction parameters and symbols 905 if (done[reductionClause] && clauseSegments[pos[reductionClause]]) { 906 if (failed(parser.resolveOperands(reductionVars, reductionVarTypes, 907 parser.getNameLoc(), result.operands))) 908 return failure(); 909 910 SmallVector<Attribute> reductions(reductionSymbols.begin(), 911 reductionSymbols.end()); 912 result.addAttribute("reductions", 913 parser.getBuilder().getArrayAttr(reductions)); 914 } 915 916 // Add linear parameters 917 if (done[linearClause] && clauseSegments[pos[linearClause]]) { 918 auto linearStepType = parser.getBuilder().getI32Type(); 919 SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType); 920 if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location, 921 result.operands)) || 922 failed(parser.resolveOperands(linearSteps, linearStepTypes, 923 linearSteps[0].location, 924 result.operands))) 925 return failure(); 926 } 927 928 // Add schedule parameters 929 if (done[scheduleClause] && !schedule.empty()) { 930 schedule[0] = llvm::toUpper(schedule[0]); 931 if (Optional<ClauseScheduleKind> sched = 932 symbolizeClauseScheduleKind(schedule)) { 933 auto attr = ClauseScheduleKindAttr::get(parser.getContext(), *sched); 934 result.addAttribute("schedule_val", attr); 935 } else { 936 return parser.emitError(parser.getCurrentLocation(), 937 "invalid schedule kind"); 938 } 939 if (!modifiers.empty()) { 940 SMLoc loc = parser.getCurrentLocation(); 941 if (Optional<ScheduleModifier> mod = 942 symbolizeScheduleModifier(modifiers[0])) { 943 result.addAttribute( 944 "schedule_modifier", 945 ScheduleModifierAttr::get(parser.getContext(), *mod)); 946 } else { 947 return parser.emitError(loc, "invalid schedule modifier"); 948 } 949 // Only SIMD attribute is allowed here! 
950 if (modifiers.size() > 1) { 951 assert(symbolizeScheduleModifier(modifiers[1]) == 952 ScheduleModifier::simd); 953 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 954 result.addAttribute("simd_modifier", attr); 955 } 956 } 957 if (scheduleChunkSize) 958 parser.resolveOperand(*scheduleChunkSize, scheduleChunkType, 959 result.operands); 960 } 961 962 segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end()); 963 964 return success(); 965 } 966 967 /// Parses a parallel operation. 968 /// 969 /// operation ::= `omp.parallel` clause-list 970 /// clause-list ::= clause | clause clause-list 971 /// clause ::= if | num-threads | private | firstprivate | shared | copyin | 972 /// allocate | default | proc-bind 973 /// 974 static ParseResult parseParallelOp(OpAsmParser &parser, 975 OperationState &result) { 976 SmallVector<ClauseType> clauses = { 977 ifClause, numThreadsClause, privateClause, 978 firstprivateClause, sharedClause, copyinClause, 979 allocateClause, defaultClause, procBindClause}; 980 981 SmallVector<int> segments; 982 983 if (failed(parseClauses(parser, result, clauses, segments))) 984 return failure(); 985 986 result.addAttribute("operand_segment_sizes", 987 parser.getBuilder().getI32VectorAttr(segments)); 988 989 Region *body = result.addRegion(); 990 SmallVector<OpAsmParser::OperandType> regionArgs; 991 SmallVector<Type> regionArgTypes; 992 if (parser.parseRegion(*body, regionArgs, regionArgTypes)) 993 return failure(); 994 return success(); 995 } 996 997 /// Parses a target operation. 
998 /// 999 /// operation ::= `omp.target` clause-list 1000 /// clause-list ::= clause | clause clause-list 1001 /// clause ::= if | device | thread_limit | nowait 1002 /// 1003 static ParseResult parseTargetOp(OpAsmParser &parser, OperationState &result) { 1004 SmallVector<ClauseType> clauses = {ifClause, deviceClause, threadLimitClause, 1005 nowaitClause}; 1006 1007 SmallVector<int> segments; 1008 1009 if (failed(parseClauses(parser, result, clauses, segments))) 1010 return failure(); 1011 1012 result.addAttribute( 1013 TargetOp::AttrSizedOperandSegments::getOperandSegmentSizeAttr(), 1014 parser.getBuilder().getI32VectorAttr(segments)); 1015 1016 Region *body = result.addRegion(); 1017 SmallVector<OpAsmParser::OperandType> regionArgs; 1018 SmallVector<Type> regionArgTypes; 1019 if (parser.parseRegion(*body, regionArgs, regionArgTypes)) 1020 return failure(); 1021 return success(); 1022 } 1023 1024 //===----------------------------------------------------------------------===// 1025 // Parser, printer and verifier for SectionsOp 1026 //===----------------------------------------------------------------------===// 1027 1028 /// Parses an OpenMP Sections operation 1029 /// 1030 /// sections ::= `omp.sections` clause-list 1031 /// clause-list ::= clause clause-list | empty 1032 /// clause ::= private | firstprivate | lastprivate | reduction | allocate | 1033 /// nowait 1034 static ParseResult parseSectionsOp(OpAsmParser &parser, 1035 OperationState &result) { 1036 1037 SmallVector<ClauseType> clauses = {privateClause, firstprivateClause, 1038 lastprivateClause, reductionClause, 1039 allocateClause, nowaitClause}; 1040 1041 SmallVector<int> segments; 1042 1043 if (failed(parseClauses(parser, result, clauses, segments))) 1044 return failure(); 1045 1046 result.addAttribute("operand_segment_sizes", 1047 parser.getBuilder().getI32VectorAttr(segments)); 1048 1049 // Now parse the body. 
1050 Region *body = result.addRegion(); 1051 if (parser.parseRegion(*body)) 1052 return failure(); 1053 return success(); 1054 } 1055 1056 static void printSectionsOp(OpAsmPrinter &p, SectionsOp op) { 1057 p << " "; 1058 printDataVars(p, op.private_vars(), "private"); 1059 printDataVars(p, op.firstprivate_vars(), "firstprivate"); 1060 printDataVars(p, op.lastprivate_vars(), "lastprivate"); 1061 1062 if (!op.reduction_vars().empty()) 1063 printReductionVarList(p, op.reductions(), op.reduction_vars()); 1064 1065 if (!op.allocate_vars().empty()) 1066 printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars()); 1067 1068 if (op.nowait()) 1069 p << "nowait"; 1070 1071 p << ' '; 1072 p.printRegion(op.region()); 1073 } 1074 1075 static LogicalResult verifySectionsOp(SectionsOp op) { 1076 1077 // A list item may not appear in more than one clause on the same directive, 1078 // except that it may be specified in both firstprivate and lastprivate 1079 // clauses. 1080 for (auto var : op.private_vars()) { 1081 if (llvm::is_contained(op.firstprivate_vars(), var)) 1082 return op.emitOpError() 1083 << "operand used in both private and firstprivate clauses"; 1084 if (llvm::is_contained(op.lastprivate_vars(), var)) 1085 return op.emitOpError() 1086 << "operand used in both private and lastprivate clauses"; 1087 } 1088 1089 if (op.allocate_vars().size() != op.allocators_vars().size()) 1090 return op.emitError( 1091 "expected equal sizes for allocate and allocator variables"); 1092 1093 for (auto &inst : *op.region().begin()) { 1094 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) 1095 op.emitOpError() 1096 << "expected omp.section op or terminator op inside region"; 1097 } 1098 1099 return verifyReductionVarList(op, op.reductions(), op.reduction_vars()); 1100 } 1101 1102 /// Parses an OpenMP Workshare Loop operation 1103 /// 1104 /// wsloop ::= `omp.wsloop` loop-control clause-list 1105 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds 1106 /// 
loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps 1107 /// steps := `step` `(`ssa-id-list`)` 1108 /// clause-list ::= clause clause-list | empty 1109 /// clause ::= private | firstprivate | lastprivate | linear | schedule | 1110 // collapse | nowait | ordered | order | reduction 1111 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) { 1112 1113 // Parse an opening `(` followed by induction variables followed by `)` 1114 SmallVector<OpAsmParser::OperandType> ivs; 1115 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, 1116 OpAsmParser::Delimiter::Paren)) 1117 return failure(); 1118 1119 int numIVs = static_cast<int>(ivs.size()); 1120 Type loopVarType; 1121 if (parser.parseColonType(loopVarType)) 1122 return failure(); 1123 1124 // Parse loop bounds. 1125 SmallVector<OpAsmParser::OperandType> lower; 1126 if (parser.parseEqual() || 1127 parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) || 1128 parser.resolveOperands(lower, loopVarType, result.operands)) 1129 return failure(); 1130 1131 SmallVector<OpAsmParser::OperandType> upper; 1132 if (parser.parseKeyword("to") || 1133 parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) || 1134 parser.resolveOperands(upper, loopVarType, result.operands)) 1135 return failure(); 1136 1137 if (succeeded(parser.parseOptionalKeyword("inclusive"))) { 1138 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 1139 result.addAttribute("inclusive", attr); 1140 } 1141 1142 // Parse step values. 
1143 SmallVector<OpAsmParser::OperandType> steps; 1144 if (parser.parseKeyword("step") || 1145 parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) || 1146 parser.resolveOperands(steps, loopVarType, result.operands)) 1147 return failure(); 1148 1149 SmallVector<ClauseType> clauses = { 1150 privateClause, firstprivateClause, lastprivateClause, linearClause, 1151 reductionClause, collapseClause, orderClause, orderedClause, 1152 nowaitClause, scheduleClause}; 1153 SmallVector<int> segments{numIVs, numIVs, numIVs}; 1154 if (failed(parseClauses(parser, result, clauses, segments))) 1155 return failure(); 1156 1157 result.addAttribute("operand_segment_sizes", 1158 parser.getBuilder().getI32VectorAttr(segments)); 1159 1160 // Now parse the body. 1161 Region *body = result.addRegion(); 1162 SmallVector<Type> ivTypes(numIVs, loopVarType); 1163 SmallVector<OpAsmParser::OperandType> blockArgs(ivs); 1164 if (parser.parseRegion(*body, blockArgs, ivTypes)) 1165 return failure(); 1166 return success(); 1167 } 1168 1169 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) { 1170 auto args = op.getRegion().front().getArguments(); 1171 p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound() 1172 << ") to (" << op.upperBound() << ") "; 1173 if (op.inclusive()) { 1174 p << "inclusive "; 1175 } 1176 p << "step (" << op.step() << ") "; 1177 1178 printDataVars(p, op.private_vars(), "private"); 1179 printDataVars(p, op.firstprivate_vars(), "firstprivate"); 1180 printDataVars(p, op.lastprivate_vars(), "lastprivate"); 1181 1182 if (!op.linear_vars().empty()) 1183 printLinearClause(p, op.linear_vars(), op.linear_step_vars()); 1184 1185 if (auto sched = op.schedule_val()) 1186 printScheduleClause(p, sched.getValue(), op.schedule_modifier(), 1187 op.simd_modifier(), op.schedule_chunk_var()); 1188 1189 if (auto collapse = op.collapse_val()) 1190 p << "collapse(" << collapse << ") "; 1191 1192 if (op.nowait()) 1193 p << "nowait "; 1194 1195 if (auto 
ordered = op.ordered_val()) 1196 p << "ordered(" << ordered << ") "; 1197 1198 if (auto order = op.order_val()) 1199 p << "order(" << stringifyClauseOrderKind(*order) << ") "; 1200 1201 if (!op.reduction_vars().empty()) 1202 printReductionVarList(p, op.reductions(), op.reduction_vars()); 1203 1204 p << ' '; 1205 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 1206 } 1207 1208 //===----------------------------------------------------------------------===// 1209 // ReductionOp 1210 //===----------------------------------------------------------------------===// 1211 1212 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, 1213 Region ®ion) { 1214 if (parser.parseOptionalKeyword("atomic")) 1215 return success(); 1216 return parser.parseRegion(region); 1217 } 1218 1219 static void printAtomicReductionRegion(OpAsmPrinter &printer, 1220 ReductionDeclareOp op, Region ®ion) { 1221 if (region.empty()) 1222 return; 1223 printer << "atomic "; 1224 printer.printRegion(region); 1225 } 1226 1227 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) { 1228 if (op.initializerRegion().empty()) 1229 return op.emitOpError() << "expects non-empty initializer region"; 1230 Block &initializerEntryBlock = op.initializerRegion().front(); 1231 if (initializerEntryBlock.getNumArguments() != 1 || 1232 initializerEntryBlock.getArgument(0).getType() != op.type()) { 1233 return op.emitOpError() << "expects initializer region with one argument " 1234 "of the reduction type"; 1235 } 1236 1237 for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) { 1238 if (yieldOp.results().size() != 1 || 1239 yieldOp.results().getTypes()[0] != op.type()) 1240 return op.emitOpError() << "expects initializer region to yield a value " 1241 "of the reduction type"; 1242 } 1243 1244 if (op.reductionRegion().empty()) 1245 return op.emitOpError() << "expects non-empty reduction region"; 1246 Block &reductionEntryBlock = op.reductionRegion().front(); 1247 if 
(reductionEntryBlock.getNumArguments() != 2 || 1248 reductionEntryBlock.getArgumentTypes()[0] != 1249 reductionEntryBlock.getArgumentTypes()[1] || 1250 reductionEntryBlock.getArgumentTypes()[0] != op.type()) 1251 return op.emitOpError() << "expects reduction region with two arguments of " 1252 "the reduction type"; 1253 for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) { 1254 if (yieldOp.results().size() != 1 || 1255 yieldOp.results().getTypes()[0] != op.type()) 1256 return op.emitOpError() << "expects reduction region to yield a value " 1257 "of the reduction type"; 1258 } 1259 1260 if (op.atomicReductionRegion().empty()) 1261 return success(); 1262 1263 Block &atomicReductionEntryBlock = op.atomicReductionRegion().front(); 1264 if (atomicReductionEntryBlock.getNumArguments() != 2 || 1265 atomicReductionEntryBlock.getArgumentTypes()[0] != 1266 atomicReductionEntryBlock.getArgumentTypes()[1]) 1267 return op.emitOpError() << "expects atomic reduction region with two " 1268 "arguments of the same type"; 1269 auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0] 1270 .dyn_cast<PointerLikeType>(); 1271 if (!ptrType || ptrType.getElementType() != op.type()) 1272 return op.emitOpError() << "expects atomic reduction region arguments to " 1273 "be accumulators containing the reduction type"; 1274 return success(); 1275 } 1276 1277 static LogicalResult verifyReductionOp(ReductionOp op) { 1278 // TODO: generalize this to an op interface when there is more than one op 1279 // that supports reductions. 
1280 auto container = op->getParentOfType<WsLoopOp>(); 1281 for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) 1282 if (container.reduction_vars()[i] == op.accumulator()) 1283 return success(); 1284 1285 return op.emitOpError() << "the accumulator is not used by the parent"; 1286 } 1287 1288 //===----------------------------------------------------------------------===// 1289 // WsLoopOp 1290 //===----------------------------------------------------------------------===// 1291 1292 void WsLoopOp::build(OpBuilder &builder, OperationState &state, 1293 ValueRange lowerBound, ValueRange upperBound, 1294 ValueRange step, ArrayRef<NamedAttribute> attributes) { 1295 build(builder, state, TypeRange(), lowerBound, upperBound, step, 1296 /*privateVars=*/ValueRange(), 1297 /*firstprivateVars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(), 1298 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), 1299 /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr, 1300 /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr, 1301 /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr, 1302 /*inclusive=*/nullptr, /*buildBody=*/false); 1303 state.addAttributes(attributes); 1304 } 1305 1306 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes, 1307 ValueRange operands, ArrayRef<NamedAttribute> attributes) { 1308 state.addOperands(operands); 1309 state.addAttributes(attributes); 1310 (void)state.addRegion(); 1311 assert(resultTypes.empty() && "mismatched number of return types"); 1312 state.addTypes(resultTypes); 1313 } 1314 1315 void WsLoopOp::build(OpBuilder &builder, OperationState &result, 1316 TypeRange typeRange, ValueRange lowerBounds, 1317 ValueRange upperBounds, ValueRange steps, 1318 ValueRange privateVars, ValueRange firstprivateVars, 1319 ValueRange lastprivateVars, ValueRange linearVars, 1320 ValueRange linearStepVars, ValueRange reductionVars, 1321 StringAttr scheduleVal, Value scheduleChunkVar, 1322 
IntegerAttr collapseVal, UnitAttr nowait, 1323 IntegerAttr orderedVal, StringAttr orderVal, 1324 UnitAttr inclusive, bool buildBody) { 1325 result.addOperands(lowerBounds); 1326 result.addOperands(upperBounds); 1327 result.addOperands(steps); 1328 result.addOperands(privateVars); 1329 result.addOperands(firstprivateVars); 1330 result.addOperands(linearVars); 1331 result.addOperands(linearStepVars); 1332 if (scheduleChunkVar) 1333 result.addOperands(scheduleChunkVar); 1334 1335 if (scheduleVal) 1336 result.addAttribute("schedule_val", scheduleVal); 1337 if (collapseVal) 1338 result.addAttribute("collapse_val", collapseVal); 1339 if (nowait) 1340 result.addAttribute("nowait", nowait); 1341 if (orderedVal) 1342 result.addAttribute("ordered_val", orderedVal); 1343 if (orderVal) 1344 result.addAttribute("order", orderVal); 1345 if (inclusive) 1346 result.addAttribute("inclusive", inclusive); 1347 result.addAttribute( 1348 WsLoopOp::getOperandSegmentSizeAttr(), 1349 builder.getI32VectorAttr( 1350 {static_cast<int32_t>(lowerBounds.size()), 1351 static_cast<int32_t>(upperBounds.size()), 1352 static_cast<int32_t>(steps.size()), 1353 static_cast<int32_t>(privateVars.size()), 1354 static_cast<int32_t>(firstprivateVars.size()), 1355 static_cast<int32_t>(lastprivateVars.size()), 1356 static_cast<int32_t>(linearVars.size()), 1357 static_cast<int32_t>(linearStepVars.size()), 1358 static_cast<int32_t>(reductionVars.size()), 1359 static_cast<int32_t>(scheduleChunkVar != nullptr ? 
1 : 0)})); 1360 1361 Region *bodyRegion = result.addRegion(); 1362 if (buildBody) { 1363 OpBuilder::InsertionGuard guard(builder); 1364 unsigned numIVs = steps.size(); 1365 SmallVector<Type, 8> argTypes(numIVs, steps.getType().front()); 1366 SmallVector<Location, 8> argLocs(numIVs, result.location); 1367 builder.createBlock(bodyRegion, {}, argTypes, argLocs); 1368 } 1369 } 1370 1371 static LogicalResult verifyWsLoopOp(WsLoopOp op) { 1372 return verifyReductionVarList(op, op.reductions(), op.reduction_vars()); 1373 } 1374 1375 //===----------------------------------------------------------------------===// 1376 // Verifier for critical construct (2.17.1) 1377 //===----------------------------------------------------------------------===// 1378 1379 static LogicalResult verifyCriticalDeclareOp(CriticalDeclareOp op) { 1380 return verifySynchronizationHint(op, op.hint()); 1381 } 1382 1383 static LogicalResult verifyCriticalOp(CriticalOp op) { 1384 1385 if (op.nameAttr()) { 1386 auto symbolRef = op.nameAttr().cast<SymbolRefAttr>(); 1387 auto decl = 1388 SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef); 1389 if (!decl) { 1390 return op.emitOpError() << "expected symbol reference " << symbolRef 1391 << " to point to a critical declaration"; 1392 } 1393 } 1394 1395 return success(); 1396 } 1397 1398 //===----------------------------------------------------------------------===// 1399 // Verifier for ordered construct 1400 //===----------------------------------------------------------------------===// 1401 1402 static LogicalResult verifyOrderedOp(OrderedOp op) { 1403 auto container = op->getParentOfType<WsLoopOp>(); 1404 if (!container || !container.ordered_valAttr() || 1405 container.ordered_valAttr().getInt() == 0) 1406 return op.emitOpError() << "ordered depend directive must be closely " 1407 << "nested inside a worksharing-loop with ordered " 1408 << "clause with parameter present"; 1409 1410 if (container.ordered_valAttr().getInt() != 1411 
(int64_t)op.num_loops_val().getValue()) 1412 return op.emitOpError() << "number of variables in depend clause does not " 1413 << "match number of iteration variables in the " 1414 << "doacross loop"; 1415 1416 return success(); 1417 } 1418 1419 static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) { 1420 // TODO: The code generation for ordered simd directive is not supported yet. 1421 if (op.simd()) 1422 return failure(); 1423 1424 if (auto container = op->getParentOfType<WsLoopOp>()) { 1425 if (!container.ordered_valAttr() || 1426 container.ordered_valAttr().getInt() != 0) 1427 return op.emitOpError() << "ordered region must be closely nested inside " 1428 << "a worksharing-loop region with an ordered " 1429 << "clause without parameter present"; 1430 } 1431 1432 return success(); 1433 } 1434 1435 //===----------------------------------------------------------------------===// 1436 // AtomicReadOp 1437 //===----------------------------------------------------------------------===// 1438 1439 /// Parser for AtomicReadOp 1440 /// 1441 /// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type 1442 /// address ::= operand `:` type 1443 static ParseResult parseAtomicReadOp(OpAsmParser &parser, 1444 OperationState &result) { 1445 OpAsmParser::OperandType x, v; 1446 Type addressType; 1447 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1448 SmallVector<int> segments; 1449 1450 if (parser.parseOperand(v) || parser.parseEqual() || parser.parseOperand(x) || 1451 parseClauses(parser, result, clauses, segments) || 1452 parser.parseColonType(addressType) || 1453 parser.resolveOperand(x, addressType, result.operands) || 1454 parser.resolveOperand(v, addressType, result.operands)) 1455 return failure(); 1456 1457 return success(); 1458 } 1459 1460 /// Printer for AtomicReadOp 1461 static void printAtomicReadOp(OpAsmPrinter &p, AtomicReadOp op) { 1462 p << " " << op.v() << " = " << op.x() << " "; 1463 if (auto mo = 
op.memory_order()) 1464 p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") "; 1465 if (op.hintAttr()) 1466 printSynchronizationHint(p << " ", op, op.hintAttr()); 1467 p << ": " << op.x().getType(); 1468 } 1469 1470 /// Verifier for AtomicReadOp 1471 static LogicalResult verifyAtomicReadOp(AtomicReadOp op) { 1472 if (auto mo = op.memory_order()) { 1473 if (*mo == ClauseMemoryOrderKind::acq_rel || 1474 *mo == ClauseMemoryOrderKind::release) { 1475 return op.emitError( 1476 "memory-order must not be acq_rel or release for atomic reads"); 1477 } 1478 } 1479 if (op.x() == op.v()) 1480 return op.emitError( 1481 "read and write must not be to the same location for atomic reads"); 1482 return verifySynchronizationHint(op, op.hint()); 1483 } 1484 1485 //===----------------------------------------------------------------------===// 1486 // AtomicWriteOp 1487 //===----------------------------------------------------------------------===// 1488 1489 /// Parser for AtomicWriteOp 1490 /// 1491 /// operation ::= `omp.atomic.write` atomic-clause-list operands 1492 /// operands ::= address `,` value 1493 /// address ::= operand `:` type 1494 /// value ::= operand `:` type 1495 static ParseResult parseAtomicWriteOp(OpAsmParser &parser, 1496 OperationState &result) { 1497 OpAsmParser::OperandType address, value; 1498 Type addrType, valueType; 1499 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1500 SmallVector<int> segments; 1501 1502 if (parser.parseOperand(address) || parser.parseEqual() || 1503 parser.parseOperand(value) || 1504 parseClauses(parser, result, clauses, segments) || 1505 parser.parseColonType(addrType) || parser.parseComma() || 1506 parser.parseType(valueType) || 1507 parser.resolveOperand(address, addrType, result.operands) || 1508 parser.resolveOperand(value, valueType, result.operands)) 1509 return failure(); 1510 return success(); 1511 } 1512 1513 /// Printer for AtomicWriteOp 1514 static void printAtomicWriteOp(OpAsmPrinter &p, 
AtomicWriteOp op) { 1515 p << " " << op.address() << " = " << op.value() << " "; 1516 if (auto mo = op.memory_order()) 1517 p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") "; 1518 if (op.hintAttr()) 1519 printSynchronizationHint(p, op, op.hintAttr()); 1520 p << ": " << op.address().getType() << ", " << op.value().getType(); 1521 } 1522 1523 /// Verifier for AtomicWriteOp 1524 static LogicalResult verifyAtomicWriteOp(AtomicWriteOp op) { 1525 if (auto mo = op.memory_order()) { 1526 if (*mo == ClauseMemoryOrderKind::acq_rel || 1527 *mo == ClauseMemoryOrderKind::acquire) { 1528 return op.emitError( 1529 "memory-order must not be acq_rel or acquire for atomic writes"); 1530 } 1531 } 1532 return verifySynchronizationHint(op, op.hint()); 1533 } 1534 1535 //===----------------------------------------------------------------------===// 1536 // AtomicUpdateOp 1537 //===----------------------------------------------------------------------===// 1538 1539 /// Parser for AtomicUpdateOp 1540 /// 1541 /// operation ::= `omp.atomic.update` atomic-clause-list region 1542 static ParseResult parseAtomicUpdateOp(OpAsmParser &parser, 1543 OperationState &result) { 1544 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1545 SmallVector<int> segments; 1546 OpAsmParser::OperandType x, y, z; 1547 Type xType, exprType; 1548 StringRef binOp; 1549 1550 // x = y `op` z : xtype, exprtype 1551 if (parser.parseOperand(x) || parser.parseEqual() || parser.parseOperand(y) || 1552 parser.parseKeyword(&binOp) || parser.parseOperand(z) || 1553 parseClauses(parser, result, clauses, segments) || parser.parseColon() || 1554 parser.parseType(xType) || parser.parseComma() || 1555 parser.parseType(exprType) || 1556 parser.resolveOperand(x, xType, result.operands)) { 1557 return failure(); 1558 } 1559 1560 auto binOpEnum = AtomicBinOpKindToEnum(binOp.upper()); 1561 if (!binOpEnum) 1562 return parser.emitError(parser.getNameLoc()) 1563 << "invalid atomic bin op in atomic 
update\n"; 1564 auto attr = 1565 parser.getBuilder().getI64IntegerAttr((int64_t)binOpEnum.getValue()); 1566 result.addAttribute("binop", attr); 1567 1568 OpAsmParser::OperandType expr; 1569 if (x.name == y.name && x.number == y.number) { 1570 expr = z; 1571 result.addAttribute("isXBinopExpr", parser.getBuilder().getUnitAttr()); 1572 } else if (x.name == z.name && x.number == z.number) { 1573 expr = y; 1574 } else { 1575 return parser.emitError(parser.getNameLoc()) 1576 << "atomic update variable " << x.name 1577 << " not found in the RHS of the assignment statement in an" 1578 " atomic.update operation"; 1579 } 1580 return parser.resolveOperand(expr, exprType, result.operands); 1581 } 1582 1583 /// Printer for AtomicUpdateOp 1584 static void printAtomicUpdateOp(OpAsmPrinter &p, AtomicUpdateOp op) { 1585 p << " " << op.x() << " = "; 1586 Value y, z; 1587 if (op.isXBinopExpr()) { 1588 y = op.x(); 1589 z = op.expr(); 1590 } else { 1591 y = op.expr(); 1592 z = op.x(); 1593 } 1594 p << y << " " << AtomicBinOpKindToString(op.binop()).lower() << " " << z 1595 << " "; 1596 if (auto mo = op.memory_order()) 1597 p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") "; 1598 if (op.hintAttr()) 1599 printSynchronizationHint(p, op, op.hintAttr()); 1600 p << ": " << op.x().getType() << ", " << op.expr().getType(); 1601 } 1602 1603 /// Verifier for AtomicUpdateOp 1604 static LogicalResult verifyAtomicUpdateOp(AtomicUpdateOp op) { 1605 if (auto mo = op.memory_order()) { 1606 if (*mo == ClauseMemoryOrderKind::acq_rel || 1607 *mo == ClauseMemoryOrderKind::acquire) { 1608 return op.emitError( 1609 "memory-order must not be acq_rel or acquire for atomic updates"); 1610 } 1611 } 1612 return success(); 1613 } 1614 1615 //===----------------------------------------------------------------------===// 1616 // AtomicCaptureOp 1617 //===----------------------------------------------------------------------===// 1618 1619 /// Parser for AtomicCaptureOp 1620 static LogicalResult 
parseAtomicCaptureOp(OpAsmParser &parser, 1621 OperationState &result) { 1622 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1623 SmallVector<int> segments; 1624 if (parseClauses(parser, result, clauses, segments) || 1625 parser.parseRegion(*result.addRegion())) 1626 return failure(); 1627 return success(); 1628 } 1629 1630 /// Printer for AtomicCaptureOp 1631 static void printAtomicCaptureOp(OpAsmPrinter &p, AtomicCaptureOp op) { 1632 if (op.memory_order()) 1633 p << "memory_order(" << op.memory_order() << ") "; 1634 if (op.hintAttr()) 1635 printSynchronizationHint(p, op, op.hintAttr()); 1636 p.printRegion(op.region()); 1637 } 1638 1639 /// Verifier for AtomicCaptureOp 1640 static LogicalResult verifyAtomicCaptureOp(AtomicCaptureOp op) { 1641 Block::OpListType &ops = op.region().front().getOperations(); 1642 if (ops.size() != 3) 1643 return emitError(op.getLoc()) 1644 << "expected three operations in omp.atomic.capture region (one " 1645 "terminator, and two atomic ops)"; 1646 auto &firstOp = ops.front(); 1647 auto &secondOp = *ops.getNextNode(firstOp); 1648 auto firstReadStmt = dyn_cast<AtomicReadOp>(firstOp); 1649 auto firstUpdateStmt = dyn_cast<AtomicUpdateOp>(firstOp); 1650 auto secondReadStmt = dyn_cast<AtomicReadOp>(secondOp); 1651 auto secondUpdateStmt = dyn_cast<AtomicUpdateOp>(secondOp); 1652 auto secondWriteStmt = dyn_cast<AtomicWriteOp>(secondOp); 1653 1654 if (!((firstUpdateStmt && secondReadStmt) || 1655 (firstReadStmt && secondUpdateStmt) || 1656 (firstReadStmt && secondWriteStmt))) 1657 return emitError(ops.front().getLoc()) 1658 << "invalid sequence of operations in the capture region"; 1659 if (firstUpdateStmt && secondReadStmt && 1660 firstUpdateStmt.x() != secondReadStmt.x()) 1661 return emitError(firstUpdateStmt.getLoc()) 1662 << "updated variable in omp.atomic.update must be captured in " 1663 "second operation"; 1664 if (firstReadStmt && secondUpdateStmt && 1665 firstReadStmt.x() != secondUpdateStmt.x()) 1666 return 
emitError(firstReadStmt.getLoc()) 1667 << "captured variable in omp.atomic.read must be updated in second " 1668 "operation"; 1669 if (firstReadStmt && secondWriteStmt && 1670 firstReadStmt.x() != secondWriteStmt.address()) 1671 return emitError(firstReadStmt.getLoc()) 1672 << "captured variable in omp.atomic.read must be updated in " 1673 "second operation"; 1674 return success(); 1675 } 1676 1677 #define GET_ATTRDEF_CLASSES 1678 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" 1679 1680 #define GET_OP_CLASSES 1681 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 1682