1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the OpenMP dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/DialectImplementation.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/IR/OperationSupport.h" 19 20 #include "llvm/ADT/BitVector.h" 21 #include "llvm/ADT/SmallString.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/ADT/StringRef.h" 24 #include "llvm/ADT/StringSwitch.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 #include <cstddef> 27 28 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc" 29 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" 30 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc" 31 32 using namespace mlir; 33 using namespace mlir::omp; 34 35 namespace { 36 /// Model for pointer-like types that already provide a `getElementType` method. 37 template <typename T> 38 struct PointerLikeModel 39 : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> { 40 Type getElementType(Type pointer) const { 41 return pointer.cast<T>().getElementType(); 42 } 43 }; 44 } // namespace 45 46 void OpenMPDialect::initialize() { 47 addOperations< 48 #define GET_OP_LIST 49 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 50 >(); 51 addAttributes< 52 #define GET_ATTRDEF_LIST 53 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" 54 >(); 55 56 LLVM::LLVMPointerType::attachInterface< 57 PointerLikeModel<LLVM::LLVMPointerType>>(*getContext()); 58 MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext()); 59 } 60 61 //===----------------------------------------------------------------------===// 62 // ParallelOp 63 //===----------------------------------------------------------------------===// 64 65 void ParallelOp::build(OpBuilder &builder, OperationState &state, 66 ArrayRef<NamedAttribute> attributes) { 67 ParallelOp::build( 68 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, 69 /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(), 70 /*proc_bind_val=*/nullptr); 71 state.addAttributes(attributes); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // Parser and printer for Allocate Clause 76 //===----------------------------------------------------------------------===// 77 78 /// Parse an allocate clause with allocators and a list of operands with types. 79 /// 80 /// allocate-operand-list :: = allocate-operand | 81 /// allocator-operand `,` allocate-operand-list 82 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 83 /// ssa-id-and-type ::= ssa-id `:` type 84 static ParseResult parseAllocateAndAllocator( 85 OpAsmParser &parser, 86 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate, 87 SmallVectorImpl<Type> &typesAllocate, 88 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator, 89 SmallVectorImpl<Type> &typesAllocator) { 90 91 return parser.parseCommaSeparatedList([&]() -> ParseResult { 92 OpAsmParser::OperandType operand; 93 Type type; 94 if (parser.parseOperand(operand) || parser.parseColonType(type)) 95 return failure(); 96 operandsAllocator.push_back(operand); 97 typesAllocator.push_back(type); 98 if (parser.parseArrow()) 99 return failure(); 100 if (parser.parseOperand(operand) || parser.parseColonType(type)) 101 return failure(); 102 103 operandsAllocate.push_back(operand); 104 typesAllocate.push_back(type); 105 return success(); 106 }); 107 } 108 109 /// Print allocate clause 110 static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, 111 OperandRange varsAllocate, 112 TypeRange typesAllocate, 113 OperandRange varsAllocator, 114 TypeRange typesAllocator) { 115 for (unsigned i = 0; i < varsAllocate.size(); ++i) { 116 std::string separator = i == varsAllocate.size() - 1 ? "" : ", "; 117 p << varsAllocator[i] << " : " << typesAllocator[i] << " -> "; 118 p << varsAllocate[i] << " : " << typesAllocate[i] << separator; 119 } 120 } 121 122 ParseResult parseProcBindKind(OpAsmParser &parser, 123 omp::ClauseProcBindKindAttr &procBindAttr) { 124 StringRef procBindStr; 125 if (parser.parseKeyword(&procBindStr)) 126 return failure(); 127 if (auto procBindVal = symbolizeClauseProcBindKind(procBindStr)) { 128 procBindAttr = 129 ClauseProcBindKindAttr::get(parser.getContext(), *procBindVal); 130 return success(); 131 } 132 return failure(); 133 } 134 135 void printProcBindKind(OpAsmPrinter &p, Operation *op, 136 omp::ClauseProcBindKindAttr procBindAttr) { 137 p << stringifyClauseProcBindKind(procBindAttr.getValue()); 138 } 139 140 LogicalResult ParallelOp::verify() { 141 if (allocate_vars().size() != allocators_vars().size()) 142 return emitError( 143 "expected equal sizes for allocate and allocator variables"); 144 return success(); 145 } 146 147 //===----------------------------------------------------------------------===// 148 // Parser and printer for Linear Clause 149 //===----------------------------------------------------------------------===// 150 151 /// linear ::= `linear` `(` linear-list `)` 152 /// linear-list := linear-val | linear-val linear-list 153 /// linear-val := ssa-id-and-type `=` ssa-id-and-type 154 static ParseResult 155 parseLinearClause(OpAsmParser &parser, 156 SmallVectorImpl<OpAsmParser::OperandType> &vars, 157 SmallVectorImpl<Type> &types, 158 SmallVectorImpl<OpAsmParser::OperandType> &stepVars) { 159 if (parser.parseLParen()) 160 return failure(); 161 162 do { 163 OpAsmParser::OperandType var; 164 Type type; 165 OpAsmParser::OperandType stepVar; 166 if (parser.parseOperand(var) || parser.parseEqual() || 167 parser.parseOperand(stepVar) || parser.parseColonType(type)) 168 return failure(); 169 170 vars.push_back(var); 171 types.push_back(type); 172 stepVars.push_back(stepVar); 173 } while (succeeded(parser.parseOptionalComma())); 174 175 if (parser.parseRParen()) 176 return failure(); 177 178 return success(); 179 } 180 181 /// Print Linear Clause 182 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars, 183 OperandRange linearStepVars) { 184 size_t linearVarsSize = linearVars.size(); 185 p << "linear("; 186 for (unsigned i = 0; i < linearVarsSize; ++i) { 187 std::string separator = i == linearVarsSize - 1 ? ") " : ", "; 188 p << linearVars[i]; 189 if (linearStepVars.size() > i) 190 p << " = " << linearStepVars[i]; 191 p << " : " << linearVars[i].getType() << separator; 192 } 193 } 194 195 //===----------------------------------------------------------------------===// 196 // Parser and printer for Schedule Clause 197 //===----------------------------------------------------------------------===// 198 199 static ParseResult 200 verifyScheduleModifiers(OpAsmParser &parser, 201 SmallVectorImpl<SmallString<12>> &modifiers) { 202 if (modifiers.size() > 2) 203 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)"; 204 for (const auto &mod : modifiers) { 205 // Translate the string. If it has no value, then it was not a valid 206 // modifier! 207 auto symbol = symbolizeScheduleModifier(mod); 208 if (!symbol.hasValue()) 209 return parser.emitError(parser.getNameLoc()) 210 << " unknown modifier type: " << mod; 211 } 212 213 // If we have one modifier that is "simd", then stick a "none" modiifer in 214 // index 0. 215 if (modifiers.size() == 1) { 216 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) { 217 modifiers.push_back(modifiers[0]); 218 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none); 219 } 220 } else if (modifiers.size() == 2) { 221 // If there are two modifier: 222 // First modifier should not be simd, second one should be simd 223 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd || 224 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd) 225 return parser.emitError(parser.getNameLoc()) 226 << " incorrect modifier order"; 227 } 228 return success(); 229 } 230 231 /// schedule ::= `schedule` `(` sched-list `)` 232 /// sched-list ::= sched-val | sched-val sched-list | 233 /// sched-val `,` sched-modifier 234 /// sched-val ::= sched-with-chunk | sched-wo-chunk 235 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 236 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided` 237 /// sched-wo-chunk ::= `auto` | `runtime` 238 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val 239 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none` 240 static ParseResult 241 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, 242 SmallVectorImpl<SmallString<12>> &modifiers, 243 Optional<OpAsmParser::OperandType> &chunkSize, 244 Type &chunkType) { 245 if (parser.parseLParen()) 246 return failure(); 247 248 StringRef keyword; 249 if (parser.parseKeyword(&keyword)) 250 return failure(); 251 252 schedule = keyword; 253 if (keyword == "static" || keyword == "dynamic" || keyword == "guided") { 254 if (succeeded(parser.parseOptionalEqual())) { 255 chunkSize = OpAsmParser::OperandType{}; 256 if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType)) 257 return failure(); 258 } else { 259 chunkSize = llvm::NoneType::None; 260 } 261 } else if (keyword == "auto" || keyword == "runtime") { 262 chunkSize = llvm::NoneType::None; 263 } else { 264 return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; 265 } 266 267 // If there is a comma, we have one or more modifiers.. 268 while (succeeded(parser.parseOptionalComma())) { 269 StringRef mod; 270 if (parser.parseKeyword(&mod)) 271 return failure(); 272 modifiers.push_back(mod); 273 } 274 275 if (parser.parseRParen()) 276 return failure(); 277 278 if (verifyScheduleModifiers(parser, modifiers)) 279 return failure(); 280 281 return success(); 282 } 283 284 /// Print schedule clause 285 static void printScheduleClause(OpAsmPrinter &p, ClauseScheduleKind sched, 286 Optional<ScheduleModifier> modifier, bool simd, 287 Value scheduleChunkVar) { 288 p << "schedule(" << stringifyClauseScheduleKind(sched).lower(); 289 if (scheduleChunkVar) 290 p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType(); 291 if (modifier) 292 p << ", " << stringifyScheduleModifier(*modifier); 293 if (simd) 294 p << ", simd"; 295 p << ") "; 296 } 297 298 //===----------------------------------------------------------------------===// 299 // Parser, printer and verifier for ReductionVarList 300 //===----------------------------------------------------------------------===// 301 302 /// reduction-entry-list ::= reduction-entry 303 /// | reduction-entry-list `,` reduction-entry 304 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type 305 static ParseResult parseReductionVarList( 306 OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &operands, 307 SmallVectorImpl<Type> &types, ArrayAttr &redcuctionSymbols) { 308 SmallVector<SymbolRefAttr> reductionVec; 309 do { 310 if (parser.parseAttribute(reductionVec.emplace_back()) || 311 parser.parseArrow() || parser.parseOperand(operands.emplace_back()) || 312 parser.parseColonType(types.emplace_back())) 313 return failure(); 314 } while (succeeded(parser.parseOptionalComma())); 315 SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end()); 316 redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions); 317 return success(); 318 } 319 320 /// Print Reduction clause 321 static void printReductionVarList(OpAsmPrinter &p, Operation *op, 322 OperandRange reductionVars, 323 TypeRange reductionTypes, 324 Optional<ArrayAttr> reductions) { 325 for (unsigned i = 0, e = reductions->size(); i < e; ++i) { 326 if (i != 0) 327 p << ", "; 328 p << (*reductions)[i] << " -> " << reductionVars[i] << " : " 329 << reductionVars[i].getType(); 330 } 331 } 332 333 /// Verifies Reduction Clause 334 static LogicalResult verifyReductionVarList(Operation *op, 335 Optional<ArrayAttr> reductions, 336 OperandRange reductionVars) { 337 if (!reductionVars.empty()) { 338 if (!reductions || reductions->size() != reductionVars.size()) 339 return op->emitOpError() 340 << "expected as many reduction symbol references " 341 "as reduction variables"; 342 } else { 343 if (reductions) 344 return op->emitOpError() << "unexpected reduction symbol references"; 345 return success(); 346 } 347 348 DenseSet<Value> accumulators; 349 for (auto args : llvm::zip(reductionVars, *reductions)) { 350 Value accum = std::get<0>(args); 351 352 if (!accumulators.insert(accum).second) 353 return op->emitOpError() << "accumulator variable used more than once"; 354 355 Type varType = accum.getType().cast<PointerLikeType>(); 356 auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>(); 357 auto decl = 358 SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef); 359 if (!decl) 360 return op->emitOpError() << "expected symbol reference " << symbolRef 361 << " to point to a reduction declaration"; 362 363 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) 364 return op->emitOpError() 365 << "expected accumulator (" << varType 366 << ") to be the same type as reduction declaration (" 367 << decl.getAccumulatorType() << ")"; 368 } 369 370 return success(); 371 } 372 373 //===----------------------------------------------------------------------===// 374 // Parser, printer and verifier for Synchronization Hint (2.17.12) 375 //===----------------------------------------------------------------------===// 376 377 /// Parses a Synchronization Hint clause. The value of hint is an integer 378 /// which is a combination of different hints from `omp_sync_hint_t`. 379 /// 380 /// hint-clause = `hint` `(` hint-value `)` 381 static ParseResult parseSynchronizationHint(OpAsmParser &parser, 382 IntegerAttr &hintAttr, 383 bool parseKeyword = true) { 384 if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) { 385 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0); 386 return success(); 387 } 388 389 if (failed(parser.parseLParen())) 390 return failure(); 391 StringRef hintKeyword; 392 int64_t hint = 0; 393 do { 394 if (failed(parser.parseKeyword(&hintKeyword))) 395 return failure(); 396 if (hintKeyword == "uncontended") 397 hint |= 1; 398 else if (hintKeyword == "contended") 399 hint |= 2; 400 else if (hintKeyword == "nonspeculative") 401 hint |= 4; 402 else if (hintKeyword == "speculative") 403 hint |= 8; 404 else 405 return parser.emitError(parser.getCurrentLocation()) 406 << hintKeyword << " is not a valid hint"; 407 } while (succeeded(parser.parseOptionalComma())); 408 if (failed(parser.parseRParen())) 409 return failure(); 410 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint); 411 return success(); 412 } 413 414 /// Prints a Synchronization Hint clause 415 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, 416 IntegerAttr hintAttr) { 417 int64_t hint = hintAttr.getInt(); 418 419 if (hint == 0) 420 return; 421 422 // Helper function to get n-th bit from the right end of `value` 423 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 424 425 bool uncontended = bitn(hint, 0); 426 bool contended = bitn(hint, 1); 427 bool nonspeculative = bitn(hint, 2); 428 bool speculative = bitn(hint, 3); 429 430 SmallVector<StringRef> hints; 431 if (uncontended) 432 hints.push_back("uncontended"); 433 if (contended) 434 hints.push_back("contended"); 435 if (nonspeculative) 436 hints.push_back("nonspeculative"); 437 if (speculative) 438 hints.push_back("speculative"); 439 440 p << "hint("; 441 llvm::interleaveComma(hints, p); 442 p << ") "; 443 } 444 445 /// Verifies a synchronization hint clause 446 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) { 447 448 // Helper function to get n-th bit from the right end of `value` 449 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 450 451 bool uncontended = bitn(hint, 0); 452 bool contended = bitn(hint, 1); 453 bool nonspeculative = bitn(hint, 2); 454 bool speculative = bitn(hint, 3); 455 456 if (uncontended && contended) 457 return op->emitOpError() << "the hints omp_sync_hint_uncontended and " 458 "omp_sync_hint_contended cannot be combined"; 459 if (nonspeculative && speculative) 460 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and " 461 "omp_sync_hint_speculative cannot be combined."; 462 return success(); 463 } 464 465 enum ClauseType { 466 ifClause, 467 numThreadsClause, 468 deviceClause, 469 threadLimitClause, 470 allocateClause, 471 procBindClause, 472 reductionClause, 473 nowaitClause, 474 linearClause, 475 scheduleClause, 476 collapseClause, 477 orderClause, 478 orderedClause, 479 memoryOrderClause, 480 hintClause, 481 COUNT 482 }; 483 484 //===----------------------------------------------------------------------===// 485 // Parser for Clause List 486 //===----------------------------------------------------------------------===// 487 488 /// Parse a clause attribute `(` $value `)`. 489 template <typename ClauseAttr> 490 static ParseResult parseClauseAttr(AsmParser &parser, OperationState &state, 491 StringRef attrName, StringRef name) { 492 using ClauseT = decltype(std::declval<ClauseAttr>().getValue()); 493 StringRef enumStr; 494 SMLoc loc = parser.getCurrentLocation(); 495 if (parser.parseLParen() || parser.parseKeyword(&enumStr) || 496 parser.parseRParen()) 497 return failure(); 498 if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) { 499 auto attr = ClauseAttr::get(parser.getContext(), *enumValue); 500 state.addAttribute(attrName, attr); 501 return success(); 502 } 503 return parser.emitError(loc, "invalid ") << name << " kind"; 504 } 505 506 /// Parse a list of clauses. The clauses can appear in any order, but their 507 /// operand segment indices are in the same order that they are passed in the 508 /// `clauses` list. The operand segments are added over the prevSegments 509 510 /// clause-list ::= clause clause-list | empty 511 /// clause ::= if | num-threads | allocate | proc-bind | reduction | nowait 512 /// | linear | schedule | collapse | order | ordered | inclusive 513 /// if ::= `if` `(` ssa-id-and-type `)` 514 /// num-threads ::= `num_threads` `(` ssa-id-and-type `)` 515 /// allocate ::= `allocate` `(` allocate-operand-list `)` 516 /// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` 517 /// reduction ::= `reduction` `(` reduction-entry-list `)` 518 /// nowait ::= `nowait` 519 /// linear ::= `linear` `(` linear-list `)` 520 /// schedule ::= `schedule` `(` sched-list `)` 521 /// collapse ::= `collapse` `(` ssa-id-and-type `)` 522 /// order ::= `order` `(` `concurrent` `)` 523 /// ordered ::= `ordered` `(` ssa-id-and-type `)` 524 /// inclusive ::= `inclusive` 525 /// 526 /// Note that each clause can only appear once in the clase-list. 527 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result, 528 SmallVectorImpl<ClauseType> &clauses, 529 SmallVectorImpl<int> &segments) { 530 531 // Check done[clause] to see if it has been parsed already 532 BitVector done(ClauseType::COUNT, false); 533 534 // See pos[clause] to get position of clause in operand segments 535 SmallVector<int> pos(ClauseType::COUNT, -1); 536 537 // Stores the last parsed clause keyword 538 StringRef clauseKeyword; 539 StringRef opName = result.name.getStringRef(); 540 541 // Containers for storing operands, types and attributes for various clauses 542 std::pair<OpAsmParser::OperandType, Type> ifCond; 543 std::pair<OpAsmParser::OperandType, Type> numThreads; 544 std::pair<OpAsmParser::OperandType, Type> device; 545 std::pair<OpAsmParser::OperandType, Type> threadLimit; 546 547 SmallVector<OpAsmParser::OperandType> allocates, allocators; 548 SmallVector<Type> allocateTypes, allocatorTypes; 549 550 ArrayAttr reductions; 551 SmallVector<OpAsmParser::OperandType> reductionVars; 552 SmallVector<Type> reductionVarTypes; 553 554 SmallVector<OpAsmParser::OperandType> linears; 555 SmallVector<Type> linearTypes; 556 SmallVector<OpAsmParser::OperandType> linearSteps; 557 558 SmallString<8> schedule; 559 SmallVector<SmallString<12>> modifiers; 560 Optional<OpAsmParser::OperandType> scheduleChunkSize; 561 Type scheduleChunkType; 562 563 // Compute the position of clauses in operand segments 564 int currPos = 0; 565 for (ClauseType clause : clauses) { 566 567 // Skip the following clauses - they do not take any position in operand 568 // segments 569 if (clause == procBindClause || clause == nowaitClause || 570 clause == collapseClause || clause == orderClause || 571 clause == orderedClause) 572 continue; 573 574 pos[clause] = currPos++; 575 576 // For the following clauses, two positions are reserved in the operand 577 // segments 578 if (clause == allocateClause || clause == linearClause) 579 currPos++; 580 } 581 582 SmallVector<int> clauseSegments(currPos); 583 584 // Helper function to check if a clause is allowed/repeated or not 585 auto checkAllowed = [&](ClauseType clause) -> ParseResult { 586 if (!llvm::is_contained(clauses, clause)) 587 return parser.emitError(parser.getCurrentLocation()) 588 << clauseKeyword << " is not a valid clause for the " << opName 589 << " operation"; 590 if (done[clause]) 591 return parser.emitError(parser.getCurrentLocation()) 592 << "at most one " << clauseKeyword << " clause can appear on the " 593 << opName << " operation"; 594 done[clause] = true; 595 return success(); 596 }; 597 598 while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) { 599 if (clauseKeyword == "if") { 600 if (checkAllowed(ifClause) || parser.parseLParen() || 601 parser.parseOperand(ifCond.first) || 602 parser.parseColonType(ifCond.second) || parser.parseRParen()) 603 return failure(); 604 clauseSegments[pos[ifClause]] = 1; 605 } else if (clauseKeyword == "num_threads") { 606 if (checkAllowed(numThreadsClause) || parser.parseLParen() || 607 parser.parseOperand(numThreads.first) || 608 parser.parseColonType(numThreads.second) || parser.parseRParen()) 609 return failure(); 610 clauseSegments[pos[numThreadsClause]] = 1; 611 } else if (clauseKeyword == "device") { 612 if (checkAllowed(deviceClause) || parser.parseLParen() || 613 parser.parseOperand(device.first) || 614 parser.parseColonType(device.second) || parser.parseRParen()) 615 return failure(); 616 clauseSegments[pos[deviceClause]] = 1; 617 } else if (clauseKeyword == "thread_limit") { 618 if (checkAllowed(threadLimitClause) || parser.parseLParen() || 619 parser.parseOperand(threadLimit.first) || 620 parser.parseColonType(threadLimit.second) || parser.parseRParen()) 621 return failure(); 622 clauseSegments[pos[threadLimitClause]] = 1; 623 } else if (clauseKeyword == "allocate") { 624 if (checkAllowed(allocateClause) || parser.parseLParen() || 625 parseAllocateAndAllocator(parser, allocates, allocateTypes, 626 allocators, allocatorTypes) || 627 parser.parseRParen()) 628 return failure(); 629 clauseSegments[pos[allocateClause]] = allocates.size(); 630 clauseSegments[pos[allocateClause] + 1] = allocators.size(); 631 } else if (clauseKeyword == "proc_bind") { 632 if (checkAllowed(procBindClause) || 633 parseClauseAttr<ClauseProcBindKindAttr>(parser, result, 634 "proc_bind_val", "proc bind")) 635 return failure(); 636 } else if (clauseKeyword == "reduction") { 637 if (checkAllowed(reductionClause) || parser.parseLParen() || 638 parseReductionVarList(parser, reductionVars, reductionVarTypes, 639 reductions) || 640 parser.parseRParen()) 641 return failure(); 642 clauseSegments[pos[reductionClause]] = reductionVars.size(); 643 } else if (clauseKeyword == "nowait") { 644 if (checkAllowed(nowaitClause)) 645 return failure(); 646 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 647 result.addAttribute("nowait", attr); 648 } else if (clauseKeyword == "linear") { 649 if (checkAllowed(linearClause) || 650 parseLinearClause(parser, linears, linearTypes, linearSteps)) 651 return failure(); 652 clauseSegments[pos[linearClause]] = linears.size(); 653 clauseSegments[pos[linearClause] + 1] = linearSteps.size(); 654 } else if (clauseKeyword == "schedule") { 655 if (checkAllowed(scheduleClause) || 656 parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize, 657 scheduleChunkType)) 658 return failure(); 659 if (scheduleChunkSize) { 660 clauseSegments[pos[scheduleClause]] = 1; 661 } 662 } else if (clauseKeyword == "collapse") { 663 auto type = parser.getBuilder().getI64Type(); 664 mlir::IntegerAttr attr; 665 if (checkAllowed(collapseClause) || parser.parseLParen() || 666 parser.parseAttribute(attr, type) || parser.parseRParen()) 667 return failure(); 668 result.addAttribute("collapse_val", attr); 669 } else if (clauseKeyword == "ordered") { 670 mlir::IntegerAttr attr; 671 if (checkAllowed(orderedClause)) 672 return failure(); 673 if (succeeded(parser.parseOptionalLParen())) { 674 auto type = parser.getBuilder().getI64Type(); 675 if (parser.parseAttribute(attr, type) || parser.parseRParen()) 676 return failure(); 677 } else { 678 // Use 0 to represent no ordered parameter was specified 679 attr = parser.getBuilder().getI64IntegerAttr(0); 680 } 681 result.addAttribute("ordered_val", attr); 682 } else if (clauseKeyword == "order") { 683 if (checkAllowed(orderClause) || 684 parseClauseAttr<ClauseOrderKindAttr>(parser, result, "order_val", 685 "order")) 686 return failure(); 687 } else if (clauseKeyword == "memory_order") { 688 if (checkAllowed(memoryOrderClause) || 689 parseClauseAttr<ClauseMemoryOrderKindAttr>( 690 parser, result, "memory_order", "memory order")) 691 return failure(); 692 } else if (clauseKeyword == "hint") { 693 IntegerAttr hint; 694 if (checkAllowed(hintClause) || 695 parseSynchronizationHint(parser, hint, false)) 696 return failure(); 697 result.addAttribute("hint", hint); 698 } else { 699 return parser.emitError(parser.getNameLoc()) 700 << clauseKeyword << " is not a valid clause"; 701 } 702 } 703 704 // Add if parameter. 705 if (done[ifClause] && clauseSegments[pos[ifClause]] && 706 failed( 707 parser.resolveOperand(ifCond.first, ifCond.second, result.operands))) 708 return failure(); 709 710 // Add num_threads parameter. 711 if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] && 712 failed(parser.resolveOperand(numThreads.first, numThreads.second, 713 result.operands))) 714 return failure(); 715 716 // Add device parameter. 717 if (done[deviceClause] && clauseSegments[pos[deviceClause]] && 718 failed( 719 parser.resolveOperand(device.first, device.second, result.operands))) 720 return failure(); 721 722 // Add thread_limit parameter. 723 if (done[threadLimitClause] && clauseSegments[pos[threadLimitClause]] && 724 failed(parser.resolveOperand(threadLimit.first, threadLimit.second, 725 result.operands))) 726 return failure(); 727 728 // Add allocate parameters. 729 if (done[allocateClause] && clauseSegments[pos[allocateClause]] && 730 failed(parser.resolveOperands(allocates, allocateTypes, 731 allocates[0].location, result.operands))) 732 return failure(); 733 734 // Add allocator parameters. 735 if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] && 736 failed(parser.resolveOperands(allocators, allocatorTypes, 737 allocators[0].location, result.operands))) 738 return failure(); 739 740 // Add reduction parameters and symbols 741 if (done[reductionClause] && clauseSegments[pos[reductionClause]]) { 742 if (failed(parser.resolveOperands(reductionVars, reductionVarTypes, 743 parser.getNameLoc(), result.operands))) 744 return failure(); 745 result.addAttribute("reductions", reductions); 746 } 747 748 // Add linear parameters 749 if (done[linearClause] && clauseSegments[pos[linearClause]]) { 750 auto linearStepType = parser.getBuilder().getI32Type(); 751 SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType); 752 if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location, 753 result.operands)) || 754 failed(parser.resolveOperands(linearSteps, linearStepTypes, 755 linearSteps[0].location, 756 result.operands))) 757 return failure(); 758 } 759 760 // Add schedule parameters 761 if (done[scheduleClause] && !schedule.empty()) { 762 schedule[0] = llvm::toUpper(schedule[0]); 763 if (Optional<ClauseScheduleKind> sched = 764 symbolizeClauseScheduleKind(schedule)) { 765 auto attr = ClauseScheduleKindAttr::get(parser.getContext(), *sched); 766 result.addAttribute("schedule_val", attr); 767 } else { 768 return parser.emitError(parser.getCurrentLocation(), 769 "invalid schedule kind"); 770 } 771 if (!modifiers.empty()) { 772 SMLoc loc = parser.getCurrentLocation(); 773 if (Optional<ScheduleModifier> mod = 774 symbolizeScheduleModifier(modifiers[0])) { 775 result.addAttribute( 776 "schedule_modifier", 777 ScheduleModifierAttr::get(parser.getContext(), *mod)); 778 } else { 779 return parser.emitError(loc, "invalid schedule modifier"); 780 } 781 // Only SIMD attribute is allowed here! 782 if (modifiers.size() > 1) { 783 assert(symbolizeScheduleModifier(modifiers[1]) == 784 ScheduleModifier::simd); 785 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 786 result.addAttribute("simd_modifier", attr); 787 } 788 } 789 if (scheduleChunkSize) 790 parser.resolveOperand(*scheduleChunkSize, scheduleChunkType, 791 result.operands); 792 } 793 794 segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end()); 795 796 return success(); 797 } 798 799 //===----------------------------------------------------------------------===// 800 // Verifier for SectionsOp 801 //===----------------------------------------------------------------------===// 802 803 LogicalResult SectionsOp::verify() { 804 if (allocate_vars().size() != allocators_vars().size()) 805 return emitError( 806 "expected equal sizes for allocate and allocator variables"); 807 808 for (auto &inst : *region().begin()) { 809 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) { 810 return emitOpError() 811 << "expected omp.section op or terminator op inside region"; 812 } 813 } 814 815 return verifyReductionVarList(*this, reductions(), reduction_vars()); 816 } 817 818 /// Parses an OpenMP Workshare Loop operation 819 /// 820 /// wsloop ::= `omp.wsloop` loop-control clause-list 821 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds 822 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps 823 /// steps := `step` `(`ssa-id-list`)` 824 /// clause-list ::= clause clause-list | empty 825 /// clause ::= linear | schedule | collapse | nowait | ordered | order 826 /// | reduction 827 ParseResult WsLoopOp::parse(OpAsmParser &parser, OperationState &result) { 828 // Parse an opening `(` followed by induction variables followed by `)` 829 SmallVector<OpAsmParser::OperandType> ivs; 830 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, 831 OpAsmParser::Delimiter::Paren)) 832 return failure(); 833 834 int numIVs = static_cast<int>(ivs.size()); 835 Type loopVarType; 836 if (parser.parseColonType(loopVarType)) 837 return failure(); 838 839 // Parse loop bounds. 840 SmallVector<OpAsmParser::OperandType> lower; 841 if (parser.parseEqual() || 842 parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) || 843 parser.resolveOperands(lower, loopVarType, result.operands)) 844 return failure(); 845 846 SmallVector<OpAsmParser::OperandType> upper; 847 if (parser.parseKeyword("to") || 848 parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) || 849 parser.resolveOperands(upper, loopVarType, result.operands)) 850 return failure(); 851 852 if (succeeded(parser.parseOptionalKeyword("inclusive"))) { 853 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 854 result.addAttribute("inclusive", attr); 855 } 856 857 // Parse step values. 858 SmallVector<OpAsmParser::OperandType> steps; 859 if (parser.parseKeyword("step") || 860 parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) || 861 parser.resolveOperands(steps, loopVarType, result.operands)) 862 return failure(); 863 864 SmallVector<ClauseType> clauses = { 865 linearClause, reductionClause, collapseClause, orderClause, 866 orderedClause, nowaitClause, scheduleClause}; 867 SmallVector<int> segments{numIVs, numIVs, numIVs}; 868 if (failed(parseClauses(parser, result, clauses, segments))) 869 return failure(); 870 871 result.addAttribute("operand_segment_sizes", 872 parser.getBuilder().getI32VectorAttr(segments)); 873 874 // Now parse the body. 875 Region *body = result.addRegion(); 876 SmallVector<Type> ivTypes(numIVs, loopVarType); 877 SmallVector<OpAsmParser::OperandType> blockArgs(ivs); 878 if (parser.parseRegion(*body, blockArgs, ivTypes)) 879 return failure(); 880 return success(); 881 } 882 883 void WsLoopOp::print(OpAsmPrinter &p) { 884 auto args = getRegion().front().getArguments(); 885 p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound() 886 << ") to (" << upperBound() << ") "; 887 if (inclusive()) { 888 p << "inclusive "; 889 } 890 p << "step (" << step() << ") "; 891 892 if (!linear_vars().empty()) 893 printLinearClause(p, linear_vars(), linear_step_vars()); 894 895 if (auto sched = schedule_val()) 896 printScheduleClause(p, sched.getValue(), schedule_modifier(), 897 simd_modifier(), schedule_chunk_var()); 898 899 if (auto collapse = collapse_val()) 900 p << "collapse(" << collapse << ") "; 901 902 if (nowait()) 903 p << "nowait "; 904 905 if (auto ordered = ordered_val()) 906 p << "ordered(" << ordered << ") "; 907 908 if (auto order = order_val()) 909 p << "order(" << stringifyClauseOrderKind(*order) << ") "; 910 911 if (!reduction_vars().empty()) { 912 printReductionVarList(p << "reduction(", *this, reduction_vars(), 913 reduction_vars().getTypes(), reductions()); 914 p << ")"; 915 } 916 917 p << ' '; 918 p.printRegion(region(), /*printEntryBlockArgs=*/false); 919 } 920 921 //===----------------------------------------------------------------------===// 922 // ReductionOp 923 //===----------------------------------------------------------------------===// 924 925 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, 926 Region ®ion) { 927 if (parser.parseOptionalKeyword("atomic")) 928 return success(); 929 return parser.parseRegion(region); 930 } 931 932 static void printAtomicReductionRegion(OpAsmPrinter &printer, 933 ReductionDeclareOp op, Region ®ion) { 934 if (region.empty()) 935 return; 936 printer << "atomic "; 937 printer.printRegion(region); 938 } 939 940 LogicalResult ReductionDeclareOp::verify() { 941 if (initializerRegion().empty()) 942 return emitOpError() << "expects non-empty initializer region"; 943 Block &initializerEntryBlock = initializerRegion().front(); 944 if (initializerEntryBlock.getNumArguments() != 1 || 945 initializerEntryBlock.getArgument(0).getType() != type()) { 946 return emitOpError() << "expects initializer region with one argument " 947 "of the reduction type"; 948 } 949 950 for (YieldOp yieldOp : initializerRegion().getOps<YieldOp>()) { 951 if (yieldOp.results().size() != 1 || 952 yieldOp.results().getTypes()[0] != type()) 953 return emitOpError() << "expects initializer region to yield a value " 954 "of the reduction type"; 955 } 956 957 if (reductionRegion().empty()) 958 return emitOpError() << "expects non-empty reduction region"; 959 Block &reductionEntryBlock = reductionRegion().front(); 960 if (reductionEntryBlock.getNumArguments() != 2 || 961 reductionEntryBlock.getArgumentTypes()[0] != 962 reductionEntryBlock.getArgumentTypes()[1] || 963 reductionEntryBlock.getArgumentTypes()[0] != type()) 964 return emitOpError() << "expects reduction region with two arguments of " 965 "the reduction type"; 966 for (YieldOp yieldOp : reductionRegion().getOps<YieldOp>()) { 967 if (yieldOp.results().size() != 1 || 968 yieldOp.results().getTypes()[0] != type()) 969 return emitOpError() << "expects reduction region to yield a value " 970 "of the reduction type"; 971 } 972 973 if (atomicReductionRegion().empty()) 974 return success(); 975 976 Block &atomicReductionEntryBlock = atomicReductionRegion().front(); 977 if (atomicReductionEntryBlock.getNumArguments() != 2 || 978 atomicReductionEntryBlock.getArgumentTypes()[0] != 979 atomicReductionEntryBlock.getArgumentTypes()[1]) 980 return emitOpError() << "expects atomic reduction region with two " 981 "arguments of the same type"; 982 auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0] 983 .dyn_cast<PointerLikeType>(); 984 if (!ptrType || ptrType.getElementType() != type()) 985 return emitOpError() << "expects atomic reduction region arguments to " 986 "be accumulators containing the reduction type"; 987 return success(); 988 } 989 990 LogicalResult ReductionOp::verify() { 991 // TODO: generalize this to an op interface when there is more than one op 992 // that supports reductions. 993 auto container = (*this)->getParentOfType<WsLoopOp>(); 994 for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) 995 if (container.reduction_vars()[i] == accumulator()) 996 return success(); 997 998 return emitOpError() << "the accumulator is not used by the parent"; 999 } 1000 1001 //===----------------------------------------------------------------------===// 1002 // WsLoopOp 1003 //===----------------------------------------------------------------------===// 1004 1005 void WsLoopOp::build(OpBuilder &builder, OperationState &state, 1006 ValueRange lowerBound, ValueRange upperBound, 1007 ValueRange step, ArrayRef<NamedAttribute> attributes) { 1008 build(builder, state, lowerBound, upperBound, step, 1009 /*linear_vars=*/ValueRange(), 1010 /*linear_step_vars=*/ValueRange(), /*reduction_vars=*/ValueRange(), 1011 /*reductions=*/nullptr, /*schedule_val=*/nullptr, 1012 /*schedule_chunk_var=*/nullptr, /*schedule_modifier=*/nullptr, 1013 /*simd_modifier=*/false, /*collapse_val=*/nullptr, /*nowait=*/false, 1014 /*ordered_val=*/nullptr, /*order_val=*/nullptr, /*inclusive=*/false); 1015 state.addAttributes(attributes); 1016 } 1017 1018 LogicalResult WsLoopOp::verify() { 1019 return verifyReductionVarList(*this, reductions(), reduction_vars()); 1020 } 1021 1022 //===----------------------------------------------------------------------===// 1023 // Verifier for critical construct (2.17.1) 1024 //===----------------------------------------------------------------------===// 1025 1026 LogicalResult CriticalDeclareOp::verify() { 1027 return verifySynchronizationHint(*this, hint()); 1028 } 1029 1030 LogicalResult CriticalOp::verify() { 1031 if (nameAttr()) { 1032 SymbolRefAttr symbolRef = nameAttr(); 1033 auto decl = SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>( 1034 *this, symbolRef); 1035 if (!decl) { 1036 return emitOpError() << "expected symbol reference " << symbolRef 1037 << " to point to a critical declaration"; 1038 } 1039 } 1040 1041 return success(); 1042 } 1043 1044 //===----------------------------------------------------------------------===// 1045 // Verifier for ordered construct 1046 //===----------------------------------------------------------------------===// 1047 1048 LogicalResult OrderedOp::verify() { 1049 auto container = (*this)->getParentOfType<WsLoopOp>(); 1050 if (!container || !container.ordered_valAttr() || 1051 container.ordered_valAttr().getInt() == 0) 1052 return emitOpError() << "ordered depend directive must be closely " 1053 << "nested inside a worksharing-loop with ordered " 1054 << "clause with parameter present"; 1055 1056 if (container.ordered_valAttr().getInt() != 1057 (int64_t)num_loops_val().getValue()) 1058 return emitOpError() << "number of variables in depend clause does not " 1059 << "match number of iteration variables in the " 1060 << "doacross loop"; 1061 1062 return success(); 1063 } 1064 1065 LogicalResult OrderedRegionOp::verify() { 1066 // TODO: The code generation for ordered simd directive is not supported yet. 1067 if (simd()) 1068 return failure(); 1069 1070 if (auto container = (*this)->getParentOfType<WsLoopOp>()) { 1071 if (!container.ordered_valAttr() || 1072 container.ordered_valAttr().getInt() != 0) 1073 return emitOpError() << "ordered region must be closely nested inside " 1074 << "a worksharing-loop region with an ordered " 1075 << "clause without parameter present"; 1076 } 1077 1078 return success(); 1079 } 1080 1081 //===----------------------------------------------------------------------===// 1082 // AtomicReadOp 1083 //===----------------------------------------------------------------------===// 1084 1085 /// Parser for AtomicReadOp 1086 /// 1087 /// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type 1088 /// address ::= operand `:` type 1089 ParseResult AtomicReadOp::parse(OpAsmParser &parser, OperationState &result) { 1090 OpAsmParser::OperandType x, v; 1091 Type addressType; 1092 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1093 SmallVector<int> segments; 1094 1095 if (parser.parseOperand(v) || parser.parseEqual() || parser.parseOperand(x) || 1096 parseClauses(parser, result, clauses, segments) || 1097 parser.parseColonType(addressType) || 1098 parser.resolveOperand(x, addressType, result.operands) || 1099 parser.resolveOperand(v, addressType, result.operands)) 1100 return failure(); 1101 1102 return success(); 1103 } 1104 1105 void AtomicReadOp::print(OpAsmPrinter &p) { 1106 p << " " << v() << " = " << x() << " "; 1107 if (auto mo = memory_order()) 1108 p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") "; 1109 if (hintAttr()) 1110 printSynchronizationHint(p << " ", *this, hintAttr()); 1111 p << ": " << x().getType(); 1112 } 1113 1114 /// Verifier for AtomicReadOp 1115 LogicalResult AtomicReadOp::verify() { 1116 if (auto mo = memory_order()) { 1117 if (*mo == ClauseMemoryOrderKind::acq_rel || 1118 *mo == ClauseMemoryOrderKind::release) { 1119 return emitError( 1120 "memory-order must not be acq_rel or release for atomic reads"); 1121 } 1122 } 1123 if (x() == v()) 1124 return emitError( 1125 "read and write must not be to the same location for atomic reads"); 1126 return verifySynchronizationHint(*this, hint()); 1127 } 1128 1129 //===----------------------------------------------------------------------===// 1130 // AtomicWriteOp 1131 //===----------------------------------------------------------------------===// 1132 1133 /// Parser for AtomicWriteOp 1134 /// 1135 /// operation ::= `omp.atomic.write` atomic-clause-list operands 1136 /// operands ::= address `,` value 1137 /// address ::= operand `:` type 1138 /// value ::= operand `:` type 1139 ParseResult AtomicWriteOp::parse(OpAsmParser &parser, OperationState &result) { 1140 OpAsmParser::OperandType address, value; 1141 Type addrType, valueType; 1142 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1143 SmallVector<int> segments; 1144 1145 if (parser.parseOperand(address) || parser.parseEqual() || 1146 parser.parseOperand(value) || 1147 parseClauses(parser, result, clauses, segments) || 1148 parser.parseColonType(addrType) || parser.parseComma() || 1149 parser.parseType(valueType) || 1150 parser.resolveOperand(address, addrType, result.operands) || 1151 parser.resolveOperand(value, valueType, result.operands)) 1152 return failure(); 1153 return success(); 1154 } 1155 1156 void AtomicWriteOp::print(OpAsmPrinter &p) { 1157 p << " " << address() << " = " << value() << " "; 1158 if (auto mo = memory_order()) 1159 p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") "; 1160 if (hintAttr()) 1161 printSynchronizationHint(p, *this, hintAttr()); 1162 p << ": " << address().getType() << ", " << value().getType(); 1163 } 1164 1165 /// Verifier for AtomicWriteOp 1166 LogicalResult AtomicWriteOp::verify() { 1167 if (auto mo = memory_order()) { 1168 if (*mo == ClauseMemoryOrderKind::acq_rel || 1169 *mo == ClauseMemoryOrderKind::acquire) { 1170 return emitError( 1171 "memory-order must not be acq_rel or acquire for atomic writes"); 1172 } 1173 } 1174 return verifySynchronizationHint(*this, hint()); 1175 } 1176 1177 //===----------------------------------------------------------------------===// 1178 // AtomicUpdateOp 1179 //===----------------------------------------------------------------------===// 1180 1181 /// Parser for AtomicUpdateOp 1182 /// 1183 /// operation ::= `omp.atomic.update` atomic-clause-list ssa-id-and-type region 1184 ParseResult AtomicUpdateOp::parse(OpAsmParser &parser, OperationState &result) { 1185 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1186 SmallVector<int> segments; 1187 OpAsmParser::OperandType x, expr; 1188 Type xType; 1189 1190 if (parseClauses(parser, result, clauses, segments) || 1191 parser.parseOperand(x) || parser.parseColon() || 1192 parser.parseType(xType) || 1193 parser.resolveOperand(x, xType, result.operands) || 1194 parser.parseRegion(*result.addRegion())) 1195 return failure(); 1196 return success(); 1197 } 1198 1199 void AtomicUpdateOp::print(OpAsmPrinter &p) { 1200 p << " "; 1201 if (auto mo = memory_order()) 1202 p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") "; 1203 if (hintAttr()) 1204 printSynchronizationHint(p, *this, hintAttr()); 1205 p << x() << " : " << x().getType(); 1206 p.printRegion(region()); 1207 } 1208 1209 /// Verifier for AtomicUpdateOp 1210 LogicalResult AtomicUpdateOp::verify() { 1211 if (auto mo = memory_order()) { 1212 if (*mo == ClauseMemoryOrderKind::acq_rel || 1213 *mo == ClauseMemoryOrderKind::acquire) { 1214 return emitError( 1215 "memory-order must not be acq_rel or acquire for atomic updates"); 1216 } 1217 } 1218 1219 if (region().getNumArguments() != 1) 1220 return emitError("the region must accept exactly one argument"); 1221 1222 if (x().getType().cast<PointerLikeType>().getElementType() != 1223 region().getArgument(0).getType()) { 1224 return emitError("the type of the operand must be a pointer type whose " 1225 "element type is the same as that of the region argument"); 1226 } 1227 1228 YieldOp yieldOp = *region().getOps<YieldOp>().begin(); 1229 1230 if (yieldOp.results().size() != 1) 1231 return emitError("only updated value must be returned"); 1232 if (yieldOp.results().front().getType() != region().getArgument(0).getType()) 1233 return emitError("input and yielded value must have the same type"); 1234 return success(); 1235 } 1236 1237 //===----------------------------------------------------------------------===// 1238 // AtomicCaptureOp 1239 //===----------------------------------------------------------------------===// 1240 1241 ParseResult AtomicCaptureOp::parse(OpAsmParser &parser, 1242 OperationState &result) { 1243 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1244 SmallVector<int> segments; 1245 if (parseClauses(parser, result, clauses, segments) || 1246 parser.parseRegion(*result.addRegion())) 1247 return failure(); 1248 return success(); 1249 } 1250 1251 void AtomicCaptureOp::print(OpAsmPrinter &p) { 1252 if (memory_order()) 1253 p << "memory_order(" << memory_order() << ") "; 1254 if (hintAttr()) 1255 printSynchronizationHint(p, *this, hintAttr()); 1256 p.printRegion(region()); 1257 } 1258 1259 /// Verifier for AtomicCaptureOp 1260 LogicalResult AtomicCaptureOp::verify() { 1261 Block::OpListType &ops = region().front().getOperations(); 1262 if (ops.size() != 3) 1263 return emitError() 1264 << "expected three operations in omp.atomic.capture region (one " 1265 "terminator, and two atomic ops)"; 1266 auto &firstOp = ops.front(); 1267 auto &secondOp = *ops.getNextNode(firstOp); 1268 auto firstReadStmt = dyn_cast<AtomicReadOp>(firstOp); 1269 auto firstUpdateStmt = dyn_cast<AtomicUpdateOp>(firstOp); 1270 auto secondReadStmt = dyn_cast<AtomicReadOp>(secondOp); 1271 auto secondUpdateStmt = dyn_cast<AtomicUpdateOp>(secondOp); 1272 auto secondWriteStmt = dyn_cast<AtomicWriteOp>(secondOp); 1273 1274 if (!((firstUpdateStmt && secondReadStmt) || 1275 (firstReadStmt && secondUpdateStmt) || 1276 (firstReadStmt && secondWriteStmt))) 1277 return ops.front().emitError() 1278 << "invalid sequence of operations in the capture region"; 1279 if (firstUpdateStmt && secondReadStmt && 1280 firstUpdateStmt.x() != secondReadStmt.x()) 1281 return firstUpdateStmt.emitError() 1282 << "updated variable in omp.atomic.update must be captured in " 1283 "second operation"; 1284 if (firstReadStmt && secondUpdateStmt && 1285 firstReadStmt.x() != secondUpdateStmt.x()) 1286 return firstReadStmt.emitError() 1287 << "captured variable in omp.atomic.read must be updated in second " 1288 "operation"; 1289 if (firstReadStmt && secondWriteStmt && 1290 firstReadStmt.x() != secondWriteStmt.address()) 1291 return firstReadStmt.emitError() 1292 << "captured variable in omp.atomic.read must be updated in " 1293 "second operation"; 1294 return success(); 1295 } 1296 1297 #define GET_ATTRDEF_CLASSES 1298 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" 1299 1300 #define GET_OP_CLASSES 1301 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 1302