1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the OpenMP dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/DialectImplementation.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/IR/OperationSupport.h" 19 20 #include "llvm/ADT/BitVector.h" 21 #include "llvm/ADT/SmallString.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/ADT/StringRef.h" 24 #include "llvm/ADT/StringSwitch.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 #include <cstddef> 27 28 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc" 29 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" 30 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc" 31 32 using namespace mlir; 33 using namespace mlir::omp; 34 35 namespace { 36 /// Model for pointer-like types that already provide a `getElementType` method. 37 template <typename T> 38 struct PointerLikeModel 39 : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> { 40 Type getElementType(Type pointer) const { 41 return pointer.cast<T>().getElementType(); 42 } 43 }; 44 } // namespace 45 46 void OpenMPDialect::initialize() { 47 addOperations< 48 #define GET_OP_LIST 49 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 50 >(); 51 addAttributes< 52 #define GET_ATTRDEF_LIST 53 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" 54 >(); 55 56 LLVM::LLVMPointerType::attachInterface< 57 PointerLikeModel<LLVM::LLVMPointerType>>(*getContext()); 58 MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext()); 59 } 60 61 //===----------------------------------------------------------------------===// 62 // ParallelOp 63 //===----------------------------------------------------------------------===// 64 65 void ParallelOp::build(OpBuilder &builder, OperationState &state, 66 ArrayRef<NamedAttribute> attributes) { 67 ParallelOp::build( 68 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, 69 /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(), 70 /*proc_bind_val=*/nullptr); 71 state.addAttributes(attributes); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // Parser and printer for Allocate Clause 76 //===----------------------------------------------------------------------===// 77 78 /// Parse an allocate clause with allocators and a list of operands with types. 79 /// 80 /// allocate ::= `allocate` `(` allocate-operand-list `)` 81 /// allocate-operand-list :: = allocate-operand | 82 /// allocator-operand `,` allocate-operand-list 83 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 84 /// ssa-id-and-type ::= ssa-id `:` type 85 static ParseResult parseAllocateAndAllocator( 86 OpAsmParser &parser, 87 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate, 88 SmallVectorImpl<Type> &typesAllocate, 89 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator, 90 SmallVectorImpl<Type> &typesAllocator) { 91 92 return parser.parseCommaSeparatedList( 93 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 94 OpAsmParser::OperandType operand; 95 Type type; 96 if (parser.parseOperand(operand) || parser.parseColonType(type)) 97 return failure(); 98 operandsAllocator.push_back(operand); 99 typesAllocator.push_back(type); 100 if (parser.parseArrow()) 101 return failure(); 102 if (parser.parseOperand(operand) || parser.parseColonType(type)) 103 return failure(); 104 105 operandsAllocate.push_back(operand); 106 typesAllocate.push_back(type); 107 return success(); 108 }); 109 } 110 111 /// Print allocate clause 112 static void printAllocateAndAllocator(OpAsmPrinter &p, 113 OperandRange varsAllocate, 114 OperandRange varsAllocator) { 115 p << "allocate("; 116 for (unsigned i = 0; i < varsAllocate.size(); ++i) { 117 std::string separator = i == varsAllocate.size() - 1 ? ") " : ", "; 118 p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> "; 119 p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator; 120 } 121 } 122 123 LogicalResult ParallelOp::verify() { 124 if (allocate_vars().size() != allocators_vars().size()) 125 return emitError( 126 "expected equal sizes for allocate and allocator variables"); 127 return success(); 128 } 129 130 void ParallelOp::print(OpAsmPrinter &p) { 131 p << " "; 132 if (auto ifCond = if_expr_var()) 133 p << "if(" << ifCond << " : " << ifCond.getType() << ") "; 134 135 if (auto threads = num_threads_var()) 136 p << "num_threads(" << threads << " : " << threads.getType() << ") "; 137 138 if (!allocate_vars().empty()) 139 printAllocateAndAllocator(p, allocate_vars(), allocators_vars()); 140 141 if (auto bind = proc_bind_val()) 142 p << "proc_bind(" << stringifyClauseProcBindKind(*bind) << ") "; 143 144 p << ' '; 145 p.printRegion(getRegion()); 146 } 147 148 void TargetOp::print(OpAsmPrinter &p) { 149 p << " "; 150 if (auto ifCond = if_expr()) 151 p << "if(" << ifCond << " : " << ifCond.getType() << ") "; 152 153 if (auto device = this->device()) 154 p << "device(" << device << " : " << device.getType() << ") "; 155 156 if (auto threads = thread_limit()) 157 p << "thread_limit(" << threads << " : " << threads.getType() << ") "; 158 159 if (nowait()) 160 p << "nowait "; 161 162 p.printRegion(getRegion()); 163 } 164 165 //===----------------------------------------------------------------------===// 166 // Parser and printer for Linear Clause 167 //===----------------------------------------------------------------------===// 168 169 /// linear ::= `linear` `(` linear-list `)` 170 /// linear-list := linear-val | linear-val linear-list 171 /// linear-val := ssa-id-and-type `=` ssa-id-and-type 172 static ParseResult 173 parseLinearClause(OpAsmParser &parser, 174 SmallVectorImpl<OpAsmParser::OperandType> &vars, 175 SmallVectorImpl<Type> &types, 176 SmallVectorImpl<OpAsmParser::OperandType> &stepVars) { 177 if (parser.parseLParen()) 178 return failure(); 179 180 do { 181 OpAsmParser::OperandType var; 182 Type type; 183 OpAsmParser::OperandType stepVar; 184 if (parser.parseOperand(var) || parser.parseEqual() || 185 parser.parseOperand(stepVar) || parser.parseColonType(type)) 186 return failure(); 187 188 vars.push_back(var); 189 types.push_back(type); 190 stepVars.push_back(stepVar); 191 } while (succeeded(parser.parseOptionalComma())); 192 193 if (parser.parseRParen()) 194 return failure(); 195 196 return success(); 197 } 198 199 /// Print Linear Clause 200 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars, 201 OperandRange linearStepVars) { 202 size_t linearVarsSize = linearVars.size(); 203 p << "linear("; 204 for (unsigned i = 0; i < linearVarsSize; ++i) { 205 std::string separator = i == linearVarsSize - 1 ? ") " : ", "; 206 p << linearVars[i]; 207 if (linearStepVars.size() > i) 208 p << " = " << linearStepVars[i]; 209 p << " : " << linearVars[i].getType() << separator; 210 } 211 } 212 213 //===----------------------------------------------------------------------===// 214 // Parser and printer for Schedule Clause 215 //===----------------------------------------------------------------------===// 216 217 static ParseResult 218 verifyScheduleModifiers(OpAsmParser &parser, 219 SmallVectorImpl<SmallString<12>> &modifiers) { 220 if (modifiers.size() > 2) 221 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)"; 222 for (const auto &mod : modifiers) { 223 // Translate the string. If it has no value, then it was not a valid 224 // modifier! 225 auto symbol = symbolizeScheduleModifier(mod); 226 if (!symbol.hasValue()) 227 return parser.emitError(parser.getNameLoc()) 228 << " unknown modifier type: " << mod; 229 } 230 231 // If we have one modifier that is "simd", then stick a "none" modiifer in 232 // index 0. 233 if (modifiers.size() == 1) { 234 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) { 235 modifiers.push_back(modifiers[0]); 236 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none); 237 } 238 } else if (modifiers.size() == 2) { 239 // If there are two modifier: 240 // First modifier should not be simd, second one should be simd 241 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd || 242 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd) 243 return parser.emitError(parser.getNameLoc()) 244 << " incorrect modifier order"; 245 } 246 return success(); 247 } 248 249 /// schedule ::= `schedule` `(` sched-list `)` 250 /// sched-list ::= sched-val | sched-val sched-list | 251 /// sched-val `,` sched-modifier 252 /// sched-val ::= sched-with-chunk | sched-wo-chunk 253 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 254 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided` 255 /// sched-wo-chunk ::= `auto` | `runtime` 256 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val 257 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none` 258 static ParseResult 259 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, 260 SmallVectorImpl<SmallString<12>> &modifiers, 261 Optional<OpAsmParser::OperandType> &chunkSize, 262 Type &chunkType) { 263 if (parser.parseLParen()) 264 return failure(); 265 266 StringRef keyword; 267 if (parser.parseKeyword(&keyword)) 268 return failure(); 269 270 schedule = keyword; 271 if (keyword == "static" || keyword == "dynamic" || keyword == "guided") { 272 if (succeeded(parser.parseOptionalEqual())) { 273 chunkSize = OpAsmParser::OperandType{}; 274 if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType)) 275 return failure(); 276 } else { 277 chunkSize = llvm::NoneType::None; 278 } 279 } else if (keyword == "auto" || keyword == "runtime") { 280 chunkSize = llvm::NoneType::None; 281 } else { 282 return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; 283 } 284 285 // If there is a comma, we have one or more modifiers.. 286 while (succeeded(parser.parseOptionalComma())) { 287 StringRef mod; 288 if (parser.parseKeyword(&mod)) 289 return failure(); 290 modifiers.push_back(mod); 291 } 292 293 if (parser.parseRParen()) 294 return failure(); 295 296 if (verifyScheduleModifiers(parser, modifiers)) 297 return failure(); 298 299 return success(); 300 } 301 302 /// Print schedule clause 303 static void printScheduleClause(OpAsmPrinter &p, ClauseScheduleKind sched, 304 Optional<ScheduleModifier> modifier, bool simd, 305 Value scheduleChunkVar) { 306 p << "schedule(" << stringifyClauseScheduleKind(sched).lower(); 307 if (scheduleChunkVar) 308 p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType(); 309 if (modifier) 310 p << ", " << stringifyScheduleModifier(*modifier); 311 if (simd) 312 p << ", simd"; 313 p << ") "; 314 } 315 316 //===----------------------------------------------------------------------===// 317 // Parser, printer and verifier for ReductionVarList 318 //===----------------------------------------------------------------------===// 319 320 /// reduction ::= `reduction` `(` reduction-entry-list `)` 321 /// reduction-entry-list ::= reduction-entry 322 /// | reduction-entry-list `,` reduction-entry 323 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type 324 static ParseResult 325 parseReductionVarList(OpAsmParser &parser, 326 SmallVectorImpl<SymbolRefAttr> &symbols, 327 SmallVectorImpl<OpAsmParser::OperandType> &operands, 328 SmallVectorImpl<Type> &types) { 329 if (failed(parser.parseLParen())) 330 return failure(); 331 332 do { 333 if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() || 334 parser.parseOperand(operands.emplace_back()) || 335 parser.parseColonType(types.emplace_back())) 336 return failure(); 337 } while (succeeded(parser.parseOptionalComma())); 338 return parser.parseRParen(); 339 } 340 341 /// Print Reduction clause 342 static void printReductionVarList(OpAsmPrinter &p, 343 Optional<ArrayAttr> reductions, 344 OperandRange reductionVars) { 345 p << "reduction("; 346 for (unsigned i = 0, e = reductions->size(); i < e; ++i) { 347 if (i != 0) 348 p << ", "; 349 p << (*reductions)[i] << " -> " << reductionVars[i] << " : " 350 << reductionVars[i].getType(); 351 } 352 p << ") "; 353 } 354 355 /// Verifies Reduction Clause 356 static LogicalResult verifyReductionVarList(Operation *op, 357 Optional<ArrayAttr> reductions, 358 OperandRange reductionVars) { 359 if (!reductionVars.empty()) { 360 if (!reductions || reductions->size() != reductionVars.size()) 361 return op->emitOpError() 362 << "expected as many reduction symbol references " 363 "as reduction variables"; 364 } else { 365 if (reductions) 366 return op->emitOpError() << "unexpected reduction symbol references"; 367 return success(); 368 } 369 370 DenseSet<Value> accumulators; 371 for (auto args : llvm::zip(reductionVars, *reductions)) { 372 Value accum = std::get<0>(args); 373 374 if (!accumulators.insert(accum).second) 375 return op->emitOpError() << "accumulator variable used more than once"; 376 377 Type varType = accum.getType().cast<PointerLikeType>(); 378 auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>(); 379 auto decl = 380 SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef); 381 if (!decl) 382 return op->emitOpError() << "expected symbol reference " << symbolRef 383 << " to point to a reduction declaration"; 384 385 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) 386 return op->emitOpError() 387 << "expected accumulator (" << varType 388 << ") to be the same type as reduction declaration (" 389 << decl.getAccumulatorType() << ")"; 390 } 391 392 return success(); 393 } 394 395 //===----------------------------------------------------------------------===// 396 // Parser, printer and verifier for Synchronization Hint (2.17.12) 397 //===----------------------------------------------------------------------===// 398 399 /// Parses a Synchronization Hint clause. The value of hint is an integer 400 /// which is a combination of different hints from `omp_sync_hint_t`. 401 /// 402 /// hint-clause = `hint` `(` hint-value `)` 403 static ParseResult parseSynchronizationHint(OpAsmParser &parser, 404 IntegerAttr &hintAttr, 405 bool parseKeyword = true) { 406 if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) { 407 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0); 408 return success(); 409 } 410 411 if (failed(parser.parseLParen())) 412 return failure(); 413 StringRef hintKeyword; 414 int64_t hint = 0; 415 do { 416 if (failed(parser.parseKeyword(&hintKeyword))) 417 return failure(); 418 if (hintKeyword == "uncontended") 419 hint |= 1; 420 else if (hintKeyword == "contended") 421 hint |= 2; 422 else if (hintKeyword == "nonspeculative") 423 hint |= 4; 424 else if (hintKeyword == "speculative") 425 hint |= 8; 426 else 427 return parser.emitError(parser.getCurrentLocation()) 428 << hintKeyword << " is not a valid hint"; 429 } while (succeeded(parser.parseOptionalComma())); 430 if (failed(parser.parseRParen())) 431 return failure(); 432 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint); 433 return success(); 434 } 435 436 /// Prints a Synchronization Hint clause 437 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, 438 IntegerAttr hintAttr) { 439 int64_t hint = hintAttr.getInt(); 440 441 if (hint == 0) 442 return; 443 444 // Helper function to get n-th bit from the right end of `value` 445 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 446 447 bool uncontended = bitn(hint, 0); 448 bool contended = bitn(hint, 1); 449 bool nonspeculative = bitn(hint, 2); 450 bool speculative = bitn(hint, 3); 451 452 SmallVector<StringRef> hints; 453 if (uncontended) 454 hints.push_back("uncontended"); 455 if (contended) 456 hints.push_back("contended"); 457 if (nonspeculative) 458 hints.push_back("nonspeculative"); 459 if (speculative) 460 hints.push_back("speculative"); 461 462 p << "hint("; 463 llvm::interleaveComma(hints, p); 464 p << ") "; 465 } 466 467 /// Verifies a synchronization hint clause 468 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) { 469 470 // Helper function to get n-th bit from the right end of `value` 471 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 472 473 bool uncontended = bitn(hint, 0); 474 bool contended = bitn(hint, 1); 475 bool nonspeculative = bitn(hint, 2); 476 bool speculative = bitn(hint, 3); 477 478 if (uncontended && contended) 479 return op->emitOpError() << "the hints omp_sync_hint_uncontended and " 480 "omp_sync_hint_contended cannot be combined"; 481 if (nonspeculative && speculative) 482 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and " 483 "omp_sync_hint_speculative cannot be combined."; 484 return success(); 485 } 486 487 enum ClauseType { 488 ifClause, 489 numThreadsClause, 490 deviceClause, 491 threadLimitClause, 492 allocateClause, 493 procBindClause, 494 reductionClause, 495 nowaitClause, 496 linearClause, 497 scheduleClause, 498 collapseClause, 499 orderClause, 500 orderedClause, 501 memoryOrderClause, 502 hintClause, 503 COUNT 504 }; 505 506 //===----------------------------------------------------------------------===// 507 // Parser for Clause List 508 //===----------------------------------------------------------------------===// 509 510 /// Parse a clause attribute `(` $value `)`. 511 template <typename ClauseAttr> 512 static ParseResult parseClauseAttr(AsmParser &parser, OperationState &state, 513 StringRef attrName, StringRef name) { 514 using ClauseT = decltype(std::declval<ClauseAttr>().getValue()); 515 StringRef enumStr; 516 SMLoc loc = parser.getCurrentLocation(); 517 if (parser.parseLParen() || parser.parseKeyword(&enumStr) || 518 parser.parseRParen()) 519 return failure(); 520 if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) { 521 auto attr = ClauseAttr::get(parser.getContext(), *enumValue); 522 state.addAttribute(attrName, attr); 523 return success(); 524 } 525 return parser.emitError(loc, "invalid ") << name << " kind"; 526 } 527 528 /// Parse a list of clauses. The clauses can appear in any order, but their 529 /// operand segment indices are in the same order that they are passed in the 530 /// `clauses` list. The operand segments are added over the prevSegments 531 532 /// clause-list ::= clause clause-list | empty 533 /// clause ::= if | num-threads | allocate | proc-bind | reduction | nowait 534 /// | linear | schedule | collapse | order | ordered | inclusive 535 /// if ::= `if` `(` ssa-id-and-type `)` 536 /// num-threads ::= `num_threads` `(` ssa-id-and-type `)` 537 /// allocate ::= `allocate` `(` allocate-operand-list `)` 538 /// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` 539 /// reduction ::= `reduction` `(` reduction-entry-list `)` 540 /// nowait ::= `nowait` 541 /// linear ::= `linear` `(` linear-list `)` 542 /// schedule ::= `schedule` `(` sched-list `)` 543 /// collapse ::= `collapse` `(` ssa-id-and-type `)` 544 /// order ::= `order` `(` `concurrent` `)` 545 /// ordered ::= `ordered` `(` ssa-id-and-type `)` 546 /// inclusive ::= `inclusive` 547 /// 548 /// Note that each clause can only appear once in the clase-list. 549 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result, 550 SmallVectorImpl<ClauseType> &clauses, 551 SmallVectorImpl<int> &segments) { 552 553 // Check done[clause] to see if it has been parsed already 554 BitVector done(ClauseType::COUNT, false); 555 556 // See pos[clause] to get position of clause in operand segments 557 SmallVector<int> pos(ClauseType::COUNT, -1); 558 559 // Stores the last parsed clause keyword 560 StringRef clauseKeyword; 561 StringRef opName = result.name.getStringRef(); 562 563 // Containers for storing operands, types and attributes for various clauses 564 std::pair<OpAsmParser::OperandType, Type> ifCond; 565 std::pair<OpAsmParser::OperandType, Type> numThreads; 566 std::pair<OpAsmParser::OperandType, Type> device; 567 std::pair<OpAsmParser::OperandType, Type> threadLimit; 568 569 SmallVector<OpAsmParser::OperandType> allocates, allocators; 570 SmallVector<Type> allocateTypes, allocatorTypes; 571 572 SmallVector<SymbolRefAttr> reductionSymbols; 573 SmallVector<OpAsmParser::OperandType> reductionVars; 574 SmallVector<Type> reductionVarTypes; 575 576 SmallVector<OpAsmParser::OperandType> linears; 577 SmallVector<Type> linearTypes; 578 SmallVector<OpAsmParser::OperandType> linearSteps; 579 580 SmallString<8> schedule; 581 SmallVector<SmallString<12>> modifiers; 582 Optional<OpAsmParser::OperandType> scheduleChunkSize; 583 Type scheduleChunkType; 584 585 // Compute the position of clauses in operand segments 586 int currPos = 0; 587 for (ClauseType clause : clauses) { 588 589 // Skip the following clauses - they do not take any position in operand 590 // segments 591 if (clause == procBindClause || clause == nowaitClause || 592 clause == collapseClause || clause == orderClause || 593 clause == orderedClause) 594 continue; 595 596 pos[clause] = currPos++; 597 598 // For the following clauses, two positions are reserved in the operand 599 // segments 600 if (clause == allocateClause || clause == linearClause) 601 currPos++; 602 } 603 604 SmallVector<int> clauseSegments(currPos); 605 606 // Helper function to check if a clause is allowed/repeated or not 607 auto checkAllowed = [&](ClauseType clause) -> ParseResult { 608 if (!llvm::is_contained(clauses, clause)) 609 return parser.emitError(parser.getCurrentLocation()) 610 << clauseKeyword << " is not a valid clause for the " << opName 611 << " operation"; 612 if (done[clause]) 613 return parser.emitError(parser.getCurrentLocation()) 614 << "at most one " << clauseKeyword << " clause can appear on the " 615 << opName << " operation"; 616 done[clause] = true; 617 return success(); 618 }; 619 620 while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) { 621 if (clauseKeyword == "if") { 622 if (checkAllowed(ifClause) || parser.parseLParen() || 623 parser.parseOperand(ifCond.first) || 624 parser.parseColonType(ifCond.second) || parser.parseRParen()) 625 return failure(); 626 clauseSegments[pos[ifClause]] = 1; 627 } else if (clauseKeyword == "num_threads") { 628 if (checkAllowed(numThreadsClause) || parser.parseLParen() || 629 parser.parseOperand(numThreads.first) || 630 parser.parseColonType(numThreads.second) || parser.parseRParen()) 631 return failure(); 632 clauseSegments[pos[numThreadsClause]] = 1; 633 } else if (clauseKeyword == "device") { 634 if (checkAllowed(deviceClause) || parser.parseLParen() || 635 parser.parseOperand(device.first) || 636 parser.parseColonType(device.second) || parser.parseRParen()) 637 return failure(); 638 clauseSegments[pos[deviceClause]] = 1; 639 } else if (clauseKeyword == "thread_limit") { 640 if (checkAllowed(threadLimitClause) || parser.parseLParen() || 641 parser.parseOperand(threadLimit.first) || 642 parser.parseColonType(threadLimit.second) || parser.parseRParen()) 643 return failure(); 644 clauseSegments[pos[threadLimitClause]] = 1; 645 } else if (clauseKeyword == "allocate") { 646 if (checkAllowed(allocateClause) || 647 parseAllocateAndAllocator(parser, allocates, allocateTypes, 648 allocators, allocatorTypes)) 649 return failure(); 650 clauseSegments[pos[allocateClause]] = allocates.size(); 651 clauseSegments[pos[allocateClause] + 1] = allocators.size(); 652 } else if (clauseKeyword == "proc_bind") { 653 if (checkAllowed(procBindClause) || 654 parseClauseAttr<ClauseProcBindKindAttr>(parser, result, 655 "proc_bind_val", "proc bind")) 656 return failure(); 657 } else if (clauseKeyword == "reduction") { 658 if (checkAllowed(reductionClause) || 659 parseReductionVarList(parser, reductionSymbols, reductionVars, 660 reductionVarTypes)) 661 return failure(); 662 clauseSegments[pos[reductionClause]] = reductionVars.size(); 663 } else if (clauseKeyword == "nowait") { 664 if (checkAllowed(nowaitClause)) 665 return failure(); 666 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 667 result.addAttribute("nowait", attr); 668 } else if (clauseKeyword == "linear") { 669 if (checkAllowed(linearClause) || 670 parseLinearClause(parser, linears, linearTypes, linearSteps)) 671 return failure(); 672 clauseSegments[pos[linearClause]] = linears.size(); 673 clauseSegments[pos[linearClause] + 1] = linearSteps.size(); 674 } else if (clauseKeyword == "schedule") { 675 if (checkAllowed(scheduleClause) || 676 parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize, 677 scheduleChunkType)) 678 return failure(); 679 if (scheduleChunkSize) { 680 clauseSegments[pos[scheduleClause]] = 1; 681 } 682 } else if (clauseKeyword == "collapse") { 683 auto type = parser.getBuilder().getI64Type(); 684 mlir::IntegerAttr attr; 685 if (checkAllowed(collapseClause) || parser.parseLParen() || 686 parser.parseAttribute(attr, type) || parser.parseRParen()) 687 return failure(); 688 result.addAttribute("collapse_val", attr); 689 } else if (clauseKeyword == "ordered") { 690 mlir::IntegerAttr attr; 691 if (checkAllowed(orderedClause)) 692 return failure(); 693 if (succeeded(parser.parseOptionalLParen())) { 694 auto type = parser.getBuilder().getI64Type(); 695 if (parser.parseAttribute(attr, type) || parser.parseRParen()) 696 return failure(); 697 } else { 698 // Use 0 to represent no ordered parameter was specified 699 attr = parser.getBuilder().getI64IntegerAttr(0); 700 } 701 result.addAttribute("ordered_val", attr); 702 } else if (clauseKeyword == "order") { 703 if (checkAllowed(orderClause) || 704 parseClauseAttr<ClauseOrderKindAttr>(parser, result, "order_val", 705 "order")) 706 return failure(); 707 } else if (clauseKeyword == "memory_order") { 708 if (checkAllowed(memoryOrderClause) || 709 parseClauseAttr<ClauseMemoryOrderKindAttr>( 710 parser, result, "memory_order", "memory order")) 711 return failure(); 712 } else if (clauseKeyword == "hint") { 713 IntegerAttr hint; 714 if (checkAllowed(hintClause) || 715 parseSynchronizationHint(parser, hint, false)) 716 return failure(); 717 result.addAttribute("hint", hint); 718 } else { 719 return parser.emitError(parser.getNameLoc()) 720 << clauseKeyword << " is not a valid clause"; 721 } 722 } 723 724 // Add if parameter. 725 if (done[ifClause] && clauseSegments[pos[ifClause]] && 726 failed( 727 parser.resolveOperand(ifCond.first, ifCond.second, result.operands))) 728 return failure(); 729 730 // Add num_threads parameter. 731 if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] && 732 failed(parser.resolveOperand(numThreads.first, numThreads.second, 733 result.operands))) 734 return failure(); 735 736 // Add device parameter. 737 if (done[deviceClause] && clauseSegments[pos[deviceClause]] && 738 failed( 739 parser.resolveOperand(device.first, device.second, result.operands))) 740 return failure(); 741 742 // Add thread_limit parameter. 743 if (done[threadLimitClause] && clauseSegments[pos[threadLimitClause]] && 744 failed(parser.resolveOperand(threadLimit.first, threadLimit.second, 745 result.operands))) 746 return failure(); 747 748 // Add allocate parameters. 749 if (done[allocateClause] && clauseSegments[pos[allocateClause]] && 750 failed(parser.resolveOperands(allocates, allocateTypes, 751 allocates[0].location, result.operands))) 752 return failure(); 753 754 // Add allocator parameters. 755 if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] && 756 failed(parser.resolveOperands(allocators, allocatorTypes, 757 allocators[0].location, result.operands))) 758 return failure(); 759 760 // Add reduction parameters and symbols 761 if (done[reductionClause] && clauseSegments[pos[reductionClause]]) { 762 if (failed(parser.resolveOperands(reductionVars, reductionVarTypes, 763 parser.getNameLoc(), result.operands))) 764 return failure(); 765 766 SmallVector<Attribute> reductions(reductionSymbols.begin(), 767 reductionSymbols.end()); 768 result.addAttribute("reductions", 769 parser.getBuilder().getArrayAttr(reductions)); 770 } 771 772 // Add linear parameters 773 if (done[linearClause] && clauseSegments[pos[linearClause]]) { 774 auto linearStepType = parser.getBuilder().getI32Type(); 775 SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType); 776 if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location, 777 result.operands)) || 778 failed(parser.resolveOperands(linearSteps, linearStepTypes, 779 linearSteps[0].location, 780 result.operands))) 781 return failure(); 782 } 783 784 // Add schedule parameters 785 if (done[scheduleClause] && !schedule.empty()) { 786 schedule[0] = llvm::toUpper(schedule[0]); 787 if (Optional<ClauseScheduleKind> sched = 788 symbolizeClauseScheduleKind(schedule)) { 789 auto attr = ClauseScheduleKindAttr::get(parser.getContext(), *sched); 790 result.addAttribute("schedule_val", attr); 791 } else { 792 return parser.emitError(parser.getCurrentLocation(), 793 "invalid schedule kind"); 794 } 795 if (!modifiers.empty()) { 796 SMLoc loc = parser.getCurrentLocation(); 797 if (Optional<ScheduleModifier> mod = 798 symbolizeScheduleModifier(modifiers[0])) { 799 result.addAttribute( 800 "schedule_modifier", 801 ScheduleModifierAttr::get(parser.getContext(), *mod)); 802 } else { 803 return parser.emitError(loc, "invalid schedule modifier"); 804 } 805 // Only SIMD attribute is allowed here! 806 if (modifiers.size() > 1) { 807 assert(symbolizeScheduleModifier(modifiers[1]) == 808 ScheduleModifier::simd); 809 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 810 result.addAttribute("simd_modifier", attr); 811 } 812 } 813 if (scheduleChunkSize) 814 parser.resolveOperand(*scheduleChunkSize, scheduleChunkType, 815 result.operands); 816 } 817 818 segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end()); 819 820 return success(); 821 } 822 823 /// Parses a parallel operation. 824 /// 825 /// operation ::= `omp.parallel` clause-list 826 /// clause-list ::= clause | clause clause-list 827 /// clause ::= if | num-threads | allocate | proc-bind 828 /// 829 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) { 830 SmallVector<ClauseType> clauses = {ifClause, numThreadsClause, allocateClause, 831 procBindClause}; 832 833 SmallVector<int> segments; 834 835 if (failed(parseClauses(parser, result, clauses, segments))) 836 return failure(); 837 838 result.addAttribute("operand_segment_sizes", 839 parser.getBuilder().getI32VectorAttr(segments)); 840 841 Region *body = result.addRegion(); 842 SmallVector<OpAsmParser::OperandType> regionArgs; 843 SmallVector<Type> regionArgTypes; 844 if (parser.parseRegion(*body, regionArgs, regionArgTypes)) 845 return failure(); 846 return success(); 847 } 848 849 /// Parses a target operation. 850 /// 851 /// operation ::= `omp.target` clause-list 852 /// clause-list ::= clause | clause clause-list 853 /// clause ::= if | device | thread_limit | nowait 854 /// 855 ParseResult TargetOp::parse(OpAsmParser &parser, OperationState &result) { 856 SmallVector<ClauseType> clauses = {ifClause, deviceClause, threadLimitClause, 857 nowaitClause}; 858 859 SmallVector<int> segments; 860 861 if (failed(parseClauses(parser, result, clauses, segments))) 862 return failure(); 863 864 result.addAttribute( 865 TargetOp::AttrSizedOperandSegments::getOperandSegmentSizeAttr(), 866 parser.getBuilder().getI32VectorAttr(segments)); 867 868 Region *body = result.addRegion(); 869 SmallVector<OpAsmParser::OperandType> regionArgs; 870 SmallVector<Type> regionArgTypes; 871 if (parser.parseRegion(*body, regionArgs, regionArgTypes)) 872 return failure(); 873 return success(); 874 } 875 876 //===----------------------------------------------------------------------===// 877 // Parser, printer and verifier for SectionsOp 878 //===----------------------------------------------------------------------===// 879 880 /// Parses an OpenMP Sections operation 881 /// 882 /// sections ::= `omp.sections` clause-list 883 /// clause-list ::= clause clause-list | empty 884 /// clause ::= reduction | allocate | nowait 885 ParseResult SectionsOp::parse(OpAsmParser &parser, OperationState &result) { 886 SmallVector<ClauseType> clauses = {reductionClause, allocateClause, 887 nowaitClause}; 888 889 SmallVector<int> segments; 890 891 if (failed(parseClauses(parser, result, clauses, segments))) 892 return failure(); 893 894 result.addAttribute("operand_segment_sizes", 895 parser.getBuilder().getI32VectorAttr(segments)); 896 897 // Now parse the body. 898 Region *body = result.addRegion(); 899 if (parser.parseRegion(*body)) 900 return failure(); 901 return success(); 902 } 903 904 void SectionsOp::print(OpAsmPrinter &p) { 905 p << " "; 906 907 if (!reduction_vars().empty()) 908 printReductionVarList(p, reductions(), reduction_vars()); 909 910 if (!allocate_vars().empty()) 911 printAllocateAndAllocator(p, allocate_vars(), allocators_vars()); 912 913 if (nowait()) 914 p << "nowait"; 915 916 p << ' '; 917 p.printRegion(region()); 918 } 919 920 LogicalResult SectionsOp::verify() { 921 if (allocate_vars().size() != allocators_vars().size()) 922 return emitError( 923 "expected equal sizes for allocate and allocator variables"); 924 925 for (auto &inst : *region().begin()) { 926 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) { 927 return emitOpError() 928 << "expected omp.section op or terminator op inside region"; 929 } 930 } 931 932 return verifyReductionVarList(*this, reductions(), reduction_vars()); 933 } 934 935 /// Parses an OpenMP Workshare Loop operation 936 /// 937 /// wsloop ::= `omp.wsloop` loop-control clause-list 938 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds 939 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps 940 /// steps := `step` `(`ssa-id-list`)` 941 /// clause-list ::= clause clause-list | empty 942 /// clause ::= linear | schedule | collapse | nowait | ordered | order 943 /// | reduction 944 ParseResult WsLoopOp::parse(OpAsmParser &parser, OperationState &result) { 945 // Parse an opening `(` followed by induction variables followed by `)` 946 SmallVector<OpAsmParser::OperandType> ivs; 947 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, 948 OpAsmParser::Delimiter::Paren)) 949 return failure(); 950 951 int numIVs = static_cast<int>(ivs.size()); 952 Type loopVarType; 953 if (parser.parseColonType(loopVarType)) 954 return failure(); 955 956 // Parse loop bounds. 957 SmallVector<OpAsmParser::OperandType> lower; 958 if (parser.parseEqual() || 959 parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) || 960 parser.resolveOperands(lower, loopVarType, result.operands)) 961 return failure(); 962 963 SmallVector<OpAsmParser::OperandType> upper; 964 if (parser.parseKeyword("to") || 965 parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) || 966 parser.resolveOperands(upper, loopVarType, result.operands)) 967 return failure(); 968 969 if (succeeded(parser.parseOptionalKeyword("inclusive"))) { 970 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 971 result.addAttribute("inclusive", attr); 972 } 973 974 // Parse step values. 975 SmallVector<OpAsmParser::OperandType> steps; 976 if (parser.parseKeyword("step") || 977 parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) || 978 parser.resolveOperands(steps, loopVarType, result.operands)) 979 return failure(); 980 981 SmallVector<ClauseType> clauses = { 982 linearClause, reductionClause, collapseClause, orderClause, 983 orderedClause, nowaitClause, scheduleClause}; 984 SmallVector<int> segments{numIVs, numIVs, numIVs}; 985 if (failed(parseClauses(parser, result, clauses, segments))) 986 return failure(); 987 988 result.addAttribute("operand_segment_sizes", 989 parser.getBuilder().getI32VectorAttr(segments)); 990 991 // Now parse the body. 992 Region *body = result.addRegion(); 993 SmallVector<Type> ivTypes(numIVs, loopVarType); 994 SmallVector<OpAsmParser::OperandType> blockArgs(ivs); 995 if (parser.parseRegion(*body, blockArgs, ivTypes)) 996 return failure(); 997 return success(); 998 } 999 1000 void WsLoopOp::print(OpAsmPrinter &p) { 1001 auto args = getRegion().front().getArguments(); 1002 p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound() 1003 << ") to (" << upperBound() << ") "; 1004 if (inclusive()) { 1005 p << "inclusive "; 1006 } 1007 p << "step (" << step() << ") "; 1008 1009 if (!linear_vars().empty()) 1010 printLinearClause(p, linear_vars(), linear_step_vars()); 1011 1012 if (auto sched = schedule_val()) 1013 printScheduleClause(p, sched.getValue(), schedule_modifier(), 1014 simd_modifier(), schedule_chunk_var()); 1015 1016 if (auto collapse = collapse_val()) 1017 p << "collapse(" << collapse << ") "; 1018 1019 if (nowait()) 1020 p << "nowait "; 1021 1022 if (auto ordered = ordered_val()) 1023 p << "ordered(" << ordered << ") "; 1024 1025 if (auto order = order_val()) 1026 p << "order(" << stringifyClauseOrderKind(*order) << ") "; 1027 1028 if (!reduction_vars().empty()) 1029 printReductionVarList(p, reductions(), reduction_vars()); 1030 1031 p << ' '; 1032 p.printRegion(region(), /*printEntryBlockArgs=*/false); 1033 } 1034 1035 //===----------------------------------------------------------------------===// 1036 // ReductionOp 1037 //===----------------------------------------------------------------------===// 1038 1039 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, 1040 Region ®ion) { 1041 if (parser.parseOptionalKeyword("atomic")) 1042 return success(); 1043 return parser.parseRegion(region); 1044 } 1045 1046 static void printAtomicReductionRegion(OpAsmPrinter &printer, 1047 ReductionDeclareOp op, Region ®ion) { 1048 if (region.empty()) 1049 return; 1050 printer << "atomic "; 1051 printer.printRegion(region); 1052 } 1053 1054 LogicalResult ReductionDeclareOp::verify() { 1055 if (initializerRegion().empty()) 1056 return emitOpError() << "expects non-empty initializer region"; 1057 Block &initializerEntryBlock = initializerRegion().front(); 1058 if (initializerEntryBlock.getNumArguments() != 1 || 1059 initializerEntryBlock.getArgument(0).getType() != type()) { 1060 return emitOpError() << "expects initializer region with one argument " 1061 "of the reduction type"; 1062 } 1063 1064 for (YieldOp yieldOp : initializerRegion().getOps<YieldOp>()) { 1065 if (yieldOp.results().size() != 1 || 1066 yieldOp.results().getTypes()[0] != type()) 1067 return emitOpError() << "expects initializer region to yield a value " 1068 "of the reduction type"; 1069 } 1070 1071 if (reductionRegion().empty()) 1072 return emitOpError() << "expects non-empty reduction region"; 1073 Block &reductionEntryBlock = reductionRegion().front(); 1074 if (reductionEntryBlock.getNumArguments() != 2 || 1075 reductionEntryBlock.getArgumentTypes()[0] != 1076 reductionEntryBlock.getArgumentTypes()[1] || 1077 reductionEntryBlock.getArgumentTypes()[0] != type()) 1078 return emitOpError() << "expects reduction region with two arguments of " 1079 "the reduction type"; 1080 for (YieldOp yieldOp : reductionRegion().getOps<YieldOp>()) { 1081 if (yieldOp.results().size() != 1 || 1082 yieldOp.results().getTypes()[0] != type()) 1083 return emitOpError() << "expects reduction region to yield a value " 1084 "of the reduction type"; 1085 } 1086 1087 if (atomicReductionRegion().empty()) 1088 return success(); 1089 1090 Block &atomicReductionEntryBlock = atomicReductionRegion().front(); 1091 if (atomicReductionEntryBlock.getNumArguments() != 2 || 1092 atomicReductionEntryBlock.getArgumentTypes()[0] != 1093 atomicReductionEntryBlock.getArgumentTypes()[1]) 1094 return emitOpError() << "expects atomic reduction region with two " 1095 "arguments of the same type"; 1096 auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0] 1097 .dyn_cast<PointerLikeType>(); 1098 if (!ptrType || ptrType.getElementType() != type()) 1099 return emitOpError() << "expects atomic reduction region arguments to " 1100 "be accumulators containing the reduction type"; 1101 return success(); 1102 } 1103 1104 LogicalResult ReductionOp::verify() { 1105 // TODO: generalize this to an op interface when there is more than one op 1106 // that supports reductions. 1107 auto container = (*this)->getParentOfType<WsLoopOp>(); 1108 for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) 1109 if (container.reduction_vars()[i] == accumulator()) 1110 return success(); 1111 1112 return emitOpError() << "the accumulator is not used by the parent"; 1113 } 1114 1115 //===----------------------------------------------------------------------===// 1116 // WsLoopOp 1117 //===----------------------------------------------------------------------===// 1118 1119 void WsLoopOp::build(OpBuilder &builder, OperationState &state, 1120 ValueRange lowerBound, ValueRange upperBound, 1121 ValueRange step, ArrayRef<NamedAttribute> attributes) { 1122 build(builder, state, TypeRange(), lowerBound, upperBound, step, 1123 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), 1124 /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr, 1125 /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr, 1126 /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr, 1127 /*inclusive=*/nullptr, /*buildBody=*/false); 1128 state.addAttributes(attributes); 1129 } 1130 1131 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes, 1132 ValueRange operands, ArrayRef<NamedAttribute> attributes) { 1133 state.addOperands(operands); 1134 state.addAttributes(attributes); 1135 (void)state.addRegion(); 1136 assert(resultTypes.empty() && "mismatched number of return types"); 1137 state.addTypes(resultTypes); 1138 } 1139 1140 void WsLoopOp::build(OpBuilder &builder, OperationState &result, 1141 TypeRange typeRange, ValueRange lowerBounds, 1142 ValueRange upperBounds, ValueRange steps, 1143 ValueRange linearVars, ValueRange linearStepVars, 1144 ValueRange reductionVars, StringAttr scheduleVal, 1145 Value scheduleChunkVar, IntegerAttr collapseVal, 1146 UnitAttr nowait, IntegerAttr orderedVal, 1147 StringAttr orderVal, UnitAttr inclusive, bool buildBody) { 1148 result.addOperands(lowerBounds); 1149 result.addOperands(upperBounds); 1150 result.addOperands(steps); 1151 result.addOperands(linearVars); 1152 result.addOperands(linearStepVars); 1153 if (scheduleChunkVar) 1154 result.addOperands(scheduleChunkVar); 1155 1156 if (scheduleVal) 1157 result.addAttribute("schedule_val", scheduleVal); 1158 if (collapseVal) 1159 result.addAttribute("collapse_val", collapseVal); 1160 if (nowait) 1161 result.addAttribute("nowait", nowait); 1162 if (orderedVal) 1163 result.addAttribute("ordered_val", orderedVal); 1164 if (orderVal) 1165 result.addAttribute("order", orderVal); 1166 if (inclusive) 1167 result.addAttribute("inclusive", inclusive); 1168 result.addAttribute( 1169 WsLoopOp::getOperandSegmentSizeAttr(), 1170 builder.getI32VectorAttr( 1171 {static_cast<int32_t>(lowerBounds.size()), 1172 static_cast<int32_t>(upperBounds.size()), 1173 static_cast<int32_t>(steps.size()), 1174 static_cast<int32_t>(linearVars.size()), 1175 static_cast<int32_t>(linearStepVars.size()), 1176 static_cast<int32_t>(reductionVars.size()), 1177 static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)})); 1178 1179 Region *bodyRegion = result.addRegion(); 1180 if (buildBody) { 1181 OpBuilder::InsertionGuard guard(builder); 1182 unsigned numIVs = steps.size(); 1183 SmallVector<Type, 8> argTypes(numIVs, steps.getType().front()); 1184 SmallVector<Location, 8> argLocs(numIVs, result.location); 1185 builder.createBlock(bodyRegion, {}, argTypes, argLocs); 1186 } 1187 } 1188 1189 LogicalResult WsLoopOp::verify() { 1190 return verifyReductionVarList(*this, reductions(), reduction_vars()); 1191 } 1192 1193 //===----------------------------------------------------------------------===// 1194 // Verifier for critical construct (2.17.1) 1195 //===----------------------------------------------------------------------===// 1196 1197 LogicalResult CriticalDeclareOp::verify() { 1198 return verifySynchronizationHint(*this, hint()); 1199 } 1200 1201 LogicalResult CriticalOp::verify() { 1202 if (nameAttr()) { 1203 SymbolRefAttr symbolRef = nameAttr(); 1204 auto decl = SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>( 1205 *this, symbolRef); 1206 if (!decl) { 1207 return emitOpError() << "expected symbol reference " << symbolRef 1208 << " to point to a critical declaration"; 1209 } 1210 } 1211 1212 return success(); 1213 } 1214 1215 //===----------------------------------------------------------------------===// 1216 // Verifier for ordered construct 1217 //===----------------------------------------------------------------------===// 1218 1219 LogicalResult OrderedOp::verify() { 1220 auto container = (*this)->getParentOfType<WsLoopOp>(); 1221 if (!container || !container.ordered_valAttr() || 1222 container.ordered_valAttr().getInt() == 0) 1223 return emitOpError() << "ordered depend directive must be closely " 1224 << "nested inside a worksharing-loop with ordered " 1225 << "clause with parameter present"; 1226 1227 if (container.ordered_valAttr().getInt() != 1228 (int64_t)num_loops_val().getValue()) 1229 return emitOpError() << "number of variables in depend clause does not " 1230 << "match number of iteration variables in the " 1231 << "doacross loop"; 1232 1233 return success(); 1234 } 1235 1236 LogicalResult OrderedRegionOp::verify() { 1237 // TODO: The code generation for ordered simd directive is not supported yet. 1238 if (simd()) 1239 return failure(); 1240 1241 if (auto container = (*this)->getParentOfType<WsLoopOp>()) { 1242 if (!container.ordered_valAttr() || 1243 container.ordered_valAttr().getInt() != 0) 1244 return emitOpError() << "ordered region must be closely nested inside " 1245 << "a worksharing-loop region with an ordered " 1246 << "clause without parameter present"; 1247 } 1248 1249 return success(); 1250 } 1251 1252 //===----------------------------------------------------------------------===// 1253 // AtomicReadOp 1254 //===----------------------------------------------------------------------===// 1255 1256 /// Parser for AtomicReadOp 1257 /// 1258 /// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type 1259 /// address ::= operand `:` type 1260 ParseResult AtomicReadOp::parse(OpAsmParser &parser, OperationState &result) { 1261 OpAsmParser::OperandType x, v; 1262 Type addressType; 1263 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1264 SmallVector<int> segments; 1265 1266 if (parser.parseOperand(v) || parser.parseEqual() || parser.parseOperand(x) || 1267 parseClauses(parser, result, clauses, segments) || 1268 parser.parseColonType(addressType) || 1269 parser.resolveOperand(x, addressType, result.operands) || 1270 parser.resolveOperand(v, addressType, result.operands)) 1271 return failure(); 1272 1273 return success(); 1274 } 1275 1276 void AtomicReadOp::print(OpAsmPrinter &p) { 1277 p << " " << v() << " = " << x() << " "; 1278 if (auto mo = memory_order()) 1279 p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") "; 1280 if (hintAttr()) 1281 printSynchronizationHint(p << " ", *this, hintAttr()); 1282 p << ": " << x().getType(); 1283 } 1284 1285 /// Verifier for AtomicReadOp 1286 LogicalResult AtomicReadOp::verify() { 1287 if (auto mo = memory_order()) { 1288 if (*mo == ClauseMemoryOrderKind::acq_rel || 1289 *mo == ClauseMemoryOrderKind::release) { 1290 return emitError( 1291 "memory-order must not be acq_rel or release for atomic reads"); 1292 } 1293 } 1294 if (x() == v()) 1295 return emitError( 1296 "read and write must not be to the same location for atomic reads"); 1297 return verifySynchronizationHint(*this, hint()); 1298 } 1299 1300 //===----------------------------------------------------------------------===// 1301 // AtomicWriteOp 1302 //===----------------------------------------------------------------------===// 1303 1304 /// Parser for AtomicWriteOp 1305 /// 1306 /// operation ::= `omp.atomic.write` atomic-clause-list operands 1307 /// operands ::= address `,` value 1308 /// address ::= operand `:` type 1309 /// value ::= operand `:` type 1310 ParseResult AtomicWriteOp::parse(OpAsmParser &parser, OperationState &result) { 1311 OpAsmParser::OperandType address, value; 1312 Type addrType, valueType; 1313 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1314 SmallVector<int> segments; 1315 1316 if (parser.parseOperand(address) || parser.parseEqual() || 1317 parser.parseOperand(value) || 1318 parseClauses(parser, result, clauses, segments) || 1319 parser.parseColonType(addrType) || parser.parseComma() || 1320 parser.parseType(valueType) || 1321 parser.resolveOperand(address, addrType, result.operands) || 1322 parser.resolveOperand(value, valueType, result.operands)) 1323 return failure(); 1324 return success(); 1325 } 1326 1327 void AtomicWriteOp::print(OpAsmPrinter &p) { 1328 p << " " << address() << " = " << value() << " "; 1329 if (auto mo = memory_order()) 1330 p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") "; 1331 if (hintAttr()) 1332 printSynchronizationHint(p, *this, hintAttr()); 1333 p << ": " << address().getType() << ", " << value().getType(); 1334 } 1335 1336 /// Verifier for AtomicWriteOp 1337 LogicalResult AtomicWriteOp::verify() { 1338 if (auto mo = memory_order()) { 1339 if (*mo == ClauseMemoryOrderKind::acq_rel || 1340 *mo == ClauseMemoryOrderKind::acquire) { 1341 return emitError( 1342 "memory-order must not be acq_rel or acquire for atomic writes"); 1343 } 1344 } 1345 return verifySynchronizationHint(*this, hint()); 1346 } 1347 1348 //===----------------------------------------------------------------------===// 1349 // AtomicUpdateOp 1350 //===----------------------------------------------------------------------===// 1351 1352 /// Parser for AtomicUpdateOp 1353 /// 1354 /// operation ::= `omp.atomic.update` atomic-clause-list ssa-id-and-type region 1355 ParseResult AtomicUpdateOp::parse(OpAsmParser &parser, OperationState &result) { 1356 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1357 SmallVector<int> segments; 1358 OpAsmParser::OperandType x, expr; 1359 Type xType; 1360 1361 if (parseClauses(parser, result, clauses, segments) || 1362 parser.parseOperand(x) || parser.parseColon() || 1363 parser.parseType(xType) || 1364 parser.resolveOperand(x, xType, result.operands) || 1365 parser.parseRegion(*result.addRegion())) 1366 return failure(); 1367 return success(); 1368 } 1369 1370 void AtomicUpdateOp::print(OpAsmPrinter &p) { 1371 p << " "; 1372 if (auto mo = memory_order()) 1373 p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") "; 1374 if (hintAttr()) 1375 printSynchronizationHint(p, *this, hintAttr()); 1376 p << x() << " : " << x().getType(); 1377 p.printRegion(region()); 1378 } 1379 1380 /// Verifier for AtomicUpdateOp 1381 LogicalResult AtomicUpdateOp::verify() { 1382 if (auto mo = memory_order()) { 1383 if (*mo == ClauseMemoryOrderKind::acq_rel || 1384 *mo == ClauseMemoryOrderKind::acquire) { 1385 return emitError( 1386 "memory-order must not be acq_rel or acquire for atomic updates"); 1387 } 1388 } 1389 1390 if (region().getNumArguments() != 1) 1391 return emitError("the region must accept exactly one argument"); 1392 1393 if (x().getType().cast<PointerLikeType>().getElementType() != 1394 region().getArgument(0).getType()) { 1395 return emitError("the type of the operand must be a pointer type whose " 1396 "element type is the same as that of the region argument"); 1397 } 1398 1399 YieldOp yieldOp = *region().getOps<YieldOp>().begin(); 1400 1401 if (yieldOp.results().size() != 1) 1402 return emitError("only updated value must be returned"); 1403 if (yieldOp.results().front().getType() != region().getArgument(0).getType()) 1404 return emitError("input and yielded value must have the same type"); 1405 return success(); 1406 } 1407 1408 //===----------------------------------------------------------------------===// 1409 // AtomicCaptureOp 1410 //===----------------------------------------------------------------------===// 1411 1412 ParseResult AtomicCaptureOp::parse(OpAsmParser &parser, 1413 OperationState &result) { 1414 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1415 SmallVector<int> segments; 1416 if (parseClauses(parser, result, clauses, segments) || 1417 parser.parseRegion(*result.addRegion())) 1418 return failure(); 1419 return success(); 1420 } 1421 1422 void AtomicCaptureOp::print(OpAsmPrinter &p) { 1423 if (memory_order()) 1424 p << "memory_order(" << memory_order() << ") "; 1425 if (hintAttr()) 1426 printSynchronizationHint(p, *this, hintAttr()); 1427 p.printRegion(region()); 1428 } 1429 1430 /// Verifier for AtomicCaptureOp 1431 LogicalResult AtomicCaptureOp::verify() { 1432 Block::OpListType &ops = region().front().getOperations(); 1433 if (ops.size() != 3) 1434 return emitError() 1435 << "expected three operations in omp.atomic.capture region (one " 1436 "terminator, and two atomic ops)"; 1437 auto &firstOp = ops.front(); 1438 auto &secondOp = *ops.getNextNode(firstOp); 1439 auto firstReadStmt = dyn_cast<AtomicReadOp>(firstOp); 1440 auto firstUpdateStmt = dyn_cast<AtomicUpdateOp>(firstOp); 1441 auto secondReadStmt = dyn_cast<AtomicReadOp>(secondOp); 1442 auto secondUpdateStmt = dyn_cast<AtomicUpdateOp>(secondOp); 1443 auto secondWriteStmt = dyn_cast<AtomicWriteOp>(secondOp); 1444 1445 if (!((firstUpdateStmt && secondReadStmt) || 1446 (firstReadStmt && secondUpdateStmt) || 1447 (firstReadStmt && secondWriteStmt))) 1448 return ops.front().emitError() 1449 << "invalid sequence of operations in the capture region"; 1450 if (firstUpdateStmt && secondReadStmt && 1451 firstUpdateStmt.x() != secondReadStmt.x()) 1452 return firstUpdateStmt.emitError() 1453 << "updated variable in omp.atomic.update must be captured in " 1454 "second operation"; 1455 if (firstReadStmt && secondUpdateStmt && 1456 firstReadStmt.x() != secondUpdateStmt.x()) 1457 return firstReadStmt.emitError() 1458 << "captured variable in omp.atomic.read must be updated in second " 1459 "operation"; 1460 if (firstReadStmt && secondWriteStmt && 1461 firstReadStmt.x() != secondWriteStmt.address()) 1462 return firstReadStmt.emitError() 1463 << "captured variable in omp.atomic.read must be updated in " 1464 "second operation"; 1465 return success(); 1466 } 1467 1468 #define GET_ATTRDEF_CLASSES 1469 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" 1470 1471 #define GET_OP_CLASSES 1472 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 1473