1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the OpenMP dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/IR/Attributes.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/IR/OperationSupport.h" 19 20 #include "llvm/ADT/BitVector.h" 21 #include "llvm/ADT/SmallString.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/ADT/StringRef.h" 24 #include "llvm/ADT/StringSwitch.h" 25 #include <cstddef> 26 27 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc" 28 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" 29 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc" 30 31 using namespace mlir; 32 using namespace mlir::omp; 33 34 namespace { 35 /// Model for pointer-like types that already provide a `getElementType` method. 36 template <typename T> 37 struct PointerLikeModel 38 : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> { 39 Type getElementType(Type pointer) const { 40 return pointer.cast<T>().getElementType(); 41 } 42 }; 43 } // end namespace 44 45 void OpenMPDialect::initialize() { 46 addOperations< 47 #define GET_OP_LIST 48 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 49 >(); 50 51 LLVM::LLVMPointerType::attachInterface< 52 PointerLikeModel<LLVM::LLVMPointerType>>(*getContext()); 53 MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext()); 54 } 55 56 //===----------------------------------------------------------------------===// 57 // ParallelOp 58 //===----------------------------------------------------------------------===// 59 60 void ParallelOp::build(OpBuilder &builder, OperationState &state, 61 ArrayRef<NamedAttribute> attributes) { 62 ParallelOp::build( 63 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, 64 /*default_val=*/nullptr, /*private_vars=*/ValueRange(), 65 /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(), 66 /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(), 67 /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr); 68 state.addAttributes(attributes); 69 } 70 71 //===----------------------------------------------------------------------===// 72 // Parser and printer for Operand and type list 73 //===----------------------------------------------------------------------===// 74 75 /// Parse a list of operands with types. 76 /// 77 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)` 78 /// ssa-id-and-type-list ::= ssa-id-and-type | 79 /// ssa-id-and-type `,` ssa-id-and-type-list 80 /// ssa-id-and-type ::= ssa-id `:` type 81 static ParseResult 82 parseOperandAndTypeList(OpAsmParser &parser, 83 SmallVectorImpl<OpAsmParser::OperandType> &operands, 84 SmallVectorImpl<Type> &types) { 85 return parser.parseCommaSeparatedList( 86 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 87 OpAsmParser::OperandType operand; 88 Type type; 89 if (parser.parseOperand(operand) || parser.parseColonType(type)) 90 return failure(); 91 operands.push_back(operand); 92 types.push_back(type); 93 return success(); 94 }); 95 } 96 97 /// Print an operand and type list with parentheses 98 static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) { 99 p << "("; 100 llvm::interleaveComma( 101 operands, p, [&](const Value &v) { p << v << " : " << v.getType(); }); 102 p << ") "; 103 } 104 105 /// Print data variables corresponding to a data-sharing clause `name` 106 static void printDataVars(OpAsmPrinter &p, OperandRange operands, 107 StringRef name) { 108 if (operands.size()) { 109 p << name; 110 printOperandAndTypeList(p, operands); 111 } 112 } 113 114 //===----------------------------------------------------------------------===// 115 // Parser and printer for Allocate Clause 116 //===----------------------------------------------------------------------===// 117 118 /// Parse an allocate clause with allocators and a list of operands with types. 119 /// 120 /// allocate ::= `allocate` `(` allocate-operand-list `)` 121 /// allocate-operand-list :: = allocate-operand | 122 /// allocator-operand `,` allocate-operand-list 123 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 124 /// ssa-id-and-type ::= ssa-id `:` type 125 static ParseResult parseAllocateAndAllocator( 126 OpAsmParser &parser, 127 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate, 128 SmallVectorImpl<Type> &typesAllocate, 129 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator, 130 SmallVectorImpl<Type> &typesAllocator) { 131 132 return parser.parseCommaSeparatedList( 133 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 134 OpAsmParser::OperandType operand; 135 Type type; 136 if (parser.parseOperand(operand) || parser.parseColonType(type)) 137 return failure(); 138 operandsAllocator.push_back(operand); 139 typesAllocator.push_back(type); 140 if (parser.parseArrow()) 141 return failure(); 142 if (parser.parseOperand(operand) || parser.parseColonType(type)) 143 return failure(); 144 145 operandsAllocate.push_back(operand); 146 typesAllocate.push_back(type); 147 return success(); 148 }); 149 } 150 151 /// Print allocate clause 152 static void printAllocateAndAllocator(OpAsmPrinter &p, 153 OperandRange varsAllocate, 154 OperandRange varsAllocator) { 155 if (varsAllocate.empty()) 156 return; 157 158 p << "allocate("; 159 for (unsigned i = 0; i < varsAllocate.size(); ++i) { 160 std::string separator = i == varsAllocate.size() - 1 ? ") " : ", "; 161 p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> "; 162 p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator; 163 } 164 } 165 166 static LogicalResult verifyParallelOp(ParallelOp op) { 167 if (op.allocate_vars().size() != op.allocators_vars().size()) 168 return op.emitError( 169 "expected equal sizes for allocate and allocator variables"); 170 return success(); 171 } 172 173 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) { 174 p << " "; 175 if (auto ifCond = op.if_expr_var()) 176 p << "if(" << ifCond << " : " << ifCond.getType() << ") "; 177 178 if (auto threads = op.num_threads_var()) 179 p << "num_threads(" << threads << " : " << threads.getType() << ") "; 180 181 printDataVars(p, op.private_vars(), "private"); 182 printDataVars(p, op.firstprivate_vars(), "firstprivate"); 183 printDataVars(p, op.shared_vars(), "shared"); 184 printDataVars(p, op.copyin_vars(), "copyin"); 185 printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars()); 186 187 if (auto def = op.default_val()) 188 p << "default(" << def->drop_front(3) << ") "; 189 190 if (auto bind = op.proc_bind_val()) 191 p << "proc_bind(" << bind << ") "; 192 193 p.printRegion(op.getRegion()); 194 } 195 196 //===----------------------------------------------------------------------===// 197 // Parser and printer for Linear Clause 198 //===----------------------------------------------------------------------===// 199 200 /// linear ::= `linear` `(` linear-list `)` 201 /// linear-list := linear-val | linear-val linear-list 202 /// linear-val := ssa-id-and-type `=` ssa-id-and-type 203 static ParseResult 204 parseLinearClause(OpAsmParser &parser, 205 SmallVectorImpl<OpAsmParser::OperandType> &vars, 206 SmallVectorImpl<Type> &types, 207 SmallVectorImpl<OpAsmParser::OperandType> &stepVars) { 208 if (parser.parseLParen()) 209 return failure(); 210 211 do { 212 OpAsmParser::OperandType var; 213 Type type; 214 OpAsmParser::OperandType stepVar; 215 if (parser.parseOperand(var) || parser.parseEqual() || 216 parser.parseOperand(stepVar) || parser.parseColonType(type)) 217 return failure(); 218 219 vars.push_back(var); 220 types.push_back(type); 221 stepVars.push_back(stepVar); 222 } while (succeeded(parser.parseOptionalComma())); 223 224 if (parser.parseRParen()) 225 return failure(); 226 227 return success(); 228 } 229 230 /// Print Linear Clause 231 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars, 232 OperandRange linearStepVars) { 233 size_t linearVarsSize = linearVars.size(); 234 p << "("; 235 for (unsigned i = 0; i < linearVarsSize; ++i) { 236 std::string separator = i == linearVarsSize - 1 ? ") " : ", "; 237 p << linearVars[i]; 238 if (linearStepVars.size() > i) 239 p << " = " << linearStepVars[i]; 240 p << " : " << linearVars[i].getType() << separator; 241 } 242 } 243 244 //===----------------------------------------------------------------------===// 245 // Parser and printer for Schedule Clause 246 //===----------------------------------------------------------------------===// 247 248 /// schedule ::= `schedule` `(` sched-list `)` 249 /// sched-list ::= sched-val | sched-val sched-list 250 /// sched-val ::= sched-with-chunk | sched-wo-chunk 251 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 252 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided` 253 /// sched-wo-chunk ::= `auto` | `runtime` 254 static ParseResult 255 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, 256 SmallVectorImpl<SmallString<12>> &modifiers, 257 Optional<OpAsmParser::OperandType> &chunkSize) { 258 if (parser.parseLParen()) 259 return failure(); 260 261 StringRef keyword; 262 if (parser.parseKeyword(&keyword)) 263 return failure(); 264 265 schedule = keyword; 266 if (keyword == "static" || keyword == "dynamic" || keyword == "guided") { 267 if (succeeded(parser.parseOptionalEqual())) { 268 chunkSize = OpAsmParser::OperandType{}; 269 if (parser.parseOperand(*chunkSize)) 270 return failure(); 271 } else { 272 chunkSize = llvm::NoneType::None; 273 } 274 } else if (keyword == "auto" || keyword == "runtime") { 275 chunkSize = llvm::NoneType::None; 276 } else { 277 return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; 278 } 279 280 // If there is a comma, we have one or more modifiers.. 281 if (succeeded(parser.parseOptionalComma())) { 282 StringRef mod; 283 if (parser.parseKeyword(&mod)) 284 return failure(); 285 modifiers.push_back(mod); 286 } 287 288 if (parser.parseRParen()) 289 return failure(); 290 291 return success(); 292 } 293 294 /// Print schedule clause 295 static void printScheduleClause(OpAsmPrinter &p, StringRef &sched, 296 llvm::Optional<StringRef> modifier, 297 Value scheduleChunkVar) { 298 std::string schedLower = sched.lower(); 299 p << "(" << schedLower; 300 if (scheduleChunkVar) 301 p << " = " << scheduleChunkVar; 302 if (modifier && modifier.getValue() != "none") 303 p << ", " << modifier; 304 p << ") "; 305 } 306 307 //===----------------------------------------------------------------------===// 308 // Parser, printer and verifier for ReductionVarList 309 //===----------------------------------------------------------------------===// 310 311 /// reduction ::= `reduction` `(` reduction-entry-list `)` 312 /// reduction-entry-list ::= reduction-entry 313 /// | reduction-entry-list `,` reduction-entry 314 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type 315 static ParseResult 316 parseReductionVarList(OpAsmParser &parser, 317 SmallVectorImpl<SymbolRefAttr> &symbols, 318 SmallVectorImpl<OpAsmParser::OperandType> &operands, 319 SmallVectorImpl<Type> &types) { 320 if (failed(parser.parseLParen())) 321 return failure(); 322 323 do { 324 if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() || 325 parser.parseOperand(operands.emplace_back()) || 326 parser.parseColonType(types.emplace_back())) 327 return failure(); 328 } while (succeeded(parser.parseOptionalComma())); 329 return parser.parseRParen(); 330 } 331 332 /// Print Reduction clause 333 static void printReductionVarList(OpAsmPrinter &p, 334 Optional<ArrayAttr> reductions, 335 OperandRange reduction_vars) { 336 for (unsigned i = 0, e = reductions->size(); i < e; ++i) { 337 if (i != 0) 338 p << ", "; 339 p << (*reductions)[i] << " -> " << reduction_vars[i] << " : " 340 << reduction_vars[i].getType(); 341 } 342 p << ") "; 343 } 344 345 /// Verifies Reduction Clause 346 static LogicalResult verifyReductionVarList(Operation *op, 347 Optional<ArrayAttr> reductions, 348 OperandRange reduction_vars) { 349 if (reduction_vars.size() != 0) { 350 if (!reductions || reductions->size() != reduction_vars.size()) 351 return op->emitOpError() 352 << "expected as many reduction symbol references " 353 "as reduction variables"; 354 } else { 355 if (reductions) 356 return op->emitOpError() << "unexpected reduction symbol references"; 357 return success(); 358 } 359 360 DenseSet<Value> accumulators; 361 for (auto args : llvm::zip(reduction_vars, *reductions)) { 362 Value accum = std::get<0>(args); 363 364 if (!accumulators.insert(accum).second) 365 return op->emitOpError() << "accumulator variable used more than once"; 366 367 Type varType = accum.getType().cast<PointerLikeType>(); 368 auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>(); 369 auto decl = 370 SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef); 371 if (!decl) 372 return op->emitOpError() << "expected symbol reference " << symbolRef 373 << " to point to a reduction declaration"; 374 375 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) 376 return op->emitOpError() 377 << "expected accumulator (" << varType 378 << ") to be the same type as reduction declaration (" 379 << decl.getAccumulatorType() << ")"; 380 } 381 382 return success(); 383 } 384 385 //===----------------------------------------------------------------------===// 386 // Parser, printer and verifier for Synchronization Hint (2.17.12) 387 //===----------------------------------------------------------------------===// 388 389 /// Parses a Synchronization Hint clause. The value of hint is an integer 390 /// which is a combination of different hints from `omp_sync_hint_t`. 391 /// 392 /// hint-clause = `hint` `(` hint-value `)` 393 static ParseResult parseSynchronizationHint(OpAsmParser &parser, 394 IntegerAttr &hintAttr, 395 bool parseKeyword = true) { 396 if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) { 397 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0); 398 return success(); 399 } 400 401 if (failed(parser.parseLParen())) 402 return failure(); 403 StringRef hintKeyword; 404 int64_t hint = 0; 405 do { 406 if (failed(parser.parseKeyword(&hintKeyword))) 407 return failure(); 408 if (hintKeyword == "uncontended") 409 hint |= 1; 410 else if (hintKeyword == "contended") 411 hint |= 2; 412 else if (hintKeyword == "nonspeculative") 413 hint |= 4; 414 else if (hintKeyword == "speculative") 415 hint |= 8; 416 else 417 return parser.emitError(parser.getCurrentLocation()) 418 << hintKeyword << " is not a valid hint"; 419 } while (succeeded(parser.parseOptionalComma())); 420 if (failed(parser.parseRParen())) 421 return failure(); 422 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint); 423 return success(); 424 } 425 426 /// Prints a Synchronization Hint clause 427 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, 428 IntegerAttr hintAttr) { 429 int64_t hint = hintAttr.getInt(); 430 431 if (hint == 0) 432 return; 433 434 // Helper function to get n-th bit from the right end of `value` 435 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 436 437 bool uncontended = bitn(hint, 0); 438 bool contended = bitn(hint, 1); 439 bool nonspeculative = bitn(hint, 2); 440 bool speculative = bitn(hint, 3); 441 442 SmallVector<StringRef> hints; 443 if (uncontended) 444 hints.push_back("uncontended"); 445 if (contended) 446 hints.push_back("contended"); 447 if (nonspeculative) 448 hints.push_back("nonspeculative"); 449 if (speculative) 450 hints.push_back("speculative"); 451 452 p << "hint("; 453 llvm::interleaveComma(hints, p); 454 p << ") "; 455 } 456 457 /// Verifies a synchronization hint clause 458 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) { 459 460 // Helper function to get n-th bit from the right end of `value` 461 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 462 463 bool uncontended = bitn(hint, 0); 464 bool contended = bitn(hint, 1); 465 bool nonspeculative = bitn(hint, 2); 466 bool speculative = bitn(hint, 3); 467 468 if (uncontended && contended) 469 return op->emitOpError() << "the hints omp_sync_hint_uncontended and " 470 "omp_sync_hint_contended cannot be combined"; 471 if (nonspeculative && speculative) 472 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and " 473 "omp_sync_hint_speculative cannot be combined."; 474 return success(); 475 } 476 477 enum ClauseType { 478 ifClause, 479 numThreadsClause, 480 privateClause, 481 firstprivateClause, 482 lastprivateClause, 483 sharedClause, 484 copyinClause, 485 allocateClause, 486 defaultClause, 487 procBindClause, 488 reductionClause, 489 nowaitClause, 490 linearClause, 491 scheduleClause, 492 collapseClause, 493 orderClause, 494 orderedClause, 495 memoryOrderClause, 496 hintClause, 497 COUNT 498 }; 499 500 //===----------------------------------------------------------------------===// 501 // Parser for Clause List 502 //===----------------------------------------------------------------------===// 503 504 /// Parse a list of clauses. The clauses can appear in any order, but their 505 /// operand segment indices are in the same order that they are passed in the 506 /// `clauses` list. The operand segments are added over the prevSegments 507 508 /// clause-list ::= clause clause-list | empty 509 /// clause ::= if | num-threads | private | firstprivate | lastprivate | 510 /// shared | copyin | allocate | default | proc-bind | reduction | 511 /// nowait | linear | schedule | collapse | order | ordered | 512 /// inclusive 513 /// if ::= `if` `(` ssa-id-and-type `)` 514 /// num-threads ::= `num_threads` `(` ssa-id-and-type `)` 515 /// private ::= `private` operand-and-type-list 516 /// firstprivate ::= `firstprivate` operand-and-type-list 517 /// lastprivate ::= `lastprivate` operand-and-type-list 518 /// shared ::= `shared` operand-and-type-list 519 /// copyin ::= `copyin` operand-and-type-list 520 /// allocate ::= `allocate` `(` allocate-operand-list `)` 521 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) 522 /// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` 523 /// reduction ::= `reduction` `(` reduction-entry-list `)` 524 /// nowait ::= `nowait` 525 /// linear ::= `linear` `(` linear-list `)` 526 /// schedule ::= `schedule` `(` sched-list `)` 527 /// collapse ::= `collapse` `(` ssa-id-and-type `)` 528 /// order ::= `order` `(` `concurrent` `)` 529 /// ordered ::= `ordered` `(` ssa-id-and-type `)` 530 /// inclusive ::= `inclusive` 531 /// 532 /// Note that each clause can only appear once in the clase-list. 533 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result, 534 SmallVectorImpl<ClauseType> &clauses, 535 SmallVectorImpl<int> &segments) { 536 537 // Check done[clause] to see if it has been parsed already 538 llvm::BitVector done(ClauseType::COUNT, false); 539 540 // See pos[clause] to get position of clause in operand segments 541 SmallVector<int> pos(ClauseType::COUNT, -1); 542 543 // Stores the last parsed clause keyword 544 StringRef clauseKeyword; 545 StringRef opName = result.name.getStringRef(); 546 547 // Containers for storing operands, types and attributes for various clauses 548 std::pair<OpAsmParser::OperandType, Type> ifCond; 549 std::pair<OpAsmParser::OperandType, Type> numThreads; 550 551 SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates, 552 shareds, copyins; 553 SmallVector<Type> privateTypes, firstprivateTypes, lastprivateTypes, 554 sharedTypes, copyinTypes; 555 556 SmallVector<OpAsmParser::OperandType> allocates, allocators; 557 SmallVector<Type> allocateTypes, allocatorTypes; 558 559 SmallVector<SymbolRefAttr> reductionSymbols; 560 SmallVector<OpAsmParser::OperandType> reductionVars; 561 SmallVector<Type> reductionVarTypes; 562 563 SmallVector<OpAsmParser::OperandType> linears; 564 SmallVector<Type> linearTypes; 565 SmallVector<OpAsmParser::OperandType> linearSteps; 566 567 SmallString<8> schedule; 568 SmallVector<SmallString<12>> modifiers; 569 Optional<OpAsmParser::OperandType> scheduleChunkSize; 570 571 // Compute the position of clauses in operand segments 572 int currPos = 0; 573 for (ClauseType clause : clauses) { 574 575 // Skip the following clauses - they do not take any position in operand 576 // segments 577 if (clause == defaultClause || clause == procBindClause || 578 clause == nowaitClause || clause == collapseClause || 579 clause == orderClause || clause == orderedClause) 580 continue; 581 582 pos[clause] = currPos++; 583 584 // For the following clauses, two positions are reserved in the operand 585 // segments 586 if (clause == allocateClause || clause == linearClause) 587 currPos++; 588 } 589 590 SmallVector<int> clauseSegments(currPos); 591 592 // Helper function to check if a clause is allowed/repeated or not 593 auto checkAllowed = [&](ClauseType clause, 594 bool allowRepeat = false) -> ParseResult { 595 if (!llvm::is_contained(clauses, clause)) 596 return parser.emitError(parser.getCurrentLocation()) 597 << clauseKeyword << " is not a valid clause for the " << opName 598 << " operation"; 599 if (done[clause] && !allowRepeat) 600 return parser.emitError(parser.getCurrentLocation()) 601 << "at most one " << clauseKeyword << " clause can appear on the " 602 << opName << " operation"; 603 done[clause] = true; 604 return success(); 605 }; 606 607 while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) { 608 if (clauseKeyword == "if") { 609 if (checkAllowed(ifClause) || parser.parseLParen() || 610 parser.parseOperand(ifCond.first) || 611 parser.parseColonType(ifCond.second) || parser.parseRParen()) 612 return failure(); 613 clauseSegments[pos[ifClause]] = 1; 614 } else if (clauseKeyword == "num_threads") { 615 if (checkAllowed(numThreadsClause) || parser.parseLParen() || 616 parser.parseOperand(numThreads.first) || 617 parser.parseColonType(numThreads.second) || parser.parseRParen()) 618 return failure(); 619 clauseSegments[pos[numThreadsClause]] = 1; 620 } else if (clauseKeyword == "private") { 621 if (checkAllowed(privateClause) || 622 parseOperandAndTypeList(parser, privates, privateTypes)) 623 return failure(); 624 clauseSegments[pos[privateClause]] = privates.size(); 625 } else if (clauseKeyword == "firstprivate") { 626 if (checkAllowed(firstprivateClause) || 627 parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) 628 return failure(); 629 clauseSegments[pos[firstprivateClause]] = firstprivates.size(); 630 } else if (clauseKeyword == "lastprivate") { 631 if (checkAllowed(lastprivateClause) || 632 parseOperandAndTypeList(parser, lastprivates, lastprivateTypes)) 633 return failure(); 634 clauseSegments[pos[lastprivateClause]] = lastprivates.size(); 635 } else if (clauseKeyword == "shared") { 636 if (checkAllowed(sharedClause) || 637 parseOperandAndTypeList(parser, shareds, sharedTypes)) 638 return failure(); 639 clauseSegments[pos[sharedClause]] = shareds.size(); 640 } else if (clauseKeyword == "copyin") { 641 if (checkAllowed(copyinClause) || 642 parseOperandAndTypeList(parser, copyins, copyinTypes)) 643 return failure(); 644 clauseSegments[pos[copyinClause]] = copyins.size(); 645 } else if (clauseKeyword == "allocate") { 646 if (checkAllowed(allocateClause) || 647 parseAllocateAndAllocator(parser, allocates, allocateTypes, 648 allocators, allocatorTypes)) 649 return failure(); 650 clauseSegments[pos[allocateClause]] = allocates.size(); 651 clauseSegments[pos[allocateClause] + 1] = allocators.size(); 652 } else if (clauseKeyword == "default") { 653 StringRef defval; 654 if (checkAllowed(defaultClause) || parser.parseLParen() || 655 parser.parseKeyword(&defval) || parser.parseRParen()) 656 return failure(); 657 // The def prefix is required for the attribute as "private" is a keyword 658 // in C++. 659 auto attr = parser.getBuilder().getStringAttr("def" + defval); 660 result.addAttribute("default_val", attr); 661 } else if (clauseKeyword == "proc_bind") { 662 StringRef bind; 663 if (checkAllowed(procBindClause) || parser.parseLParen() || 664 parser.parseKeyword(&bind) || parser.parseRParen()) 665 return failure(); 666 auto attr = parser.getBuilder().getStringAttr(bind); 667 result.addAttribute("proc_bind_val", attr); 668 } else if (clauseKeyword == "reduction") { 669 if (checkAllowed(reductionClause) || 670 parseReductionVarList(parser, reductionSymbols, reductionVars, 671 reductionVarTypes)) 672 return failure(); 673 clauseSegments[pos[reductionClause]] = reductionVars.size(); 674 } else if (clauseKeyword == "nowait") { 675 if (checkAllowed(nowaitClause)) 676 return failure(); 677 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 678 result.addAttribute("nowait", attr); 679 } else if (clauseKeyword == "linear") { 680 if (checkAllowed(linearClause) || 681 parseLinearClause(parser, linears, linearTypes, linearSteps)) 682 return failure(); 683 clauseSegments[pos[linearClause]] = linears.size(); 684 clauseSegments[pos[linearClause] + 1] = linearSteps.size(); 685 } else if (clauseKeyword == "schedule") { 686 if (checkAllowed(scheduleClause) || 687 parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize)) 688 return failure(); 689 if (scheduleChunkSize) { 690 clauseSegments[pos[scheduleClause]] = 1; 691 } 692 } else if (clauseKeyword == "collapse") { 693 auto type = parser.getBuilder().getI64Type(); 694 mlir::IntegerAttr attr; 695 if (checkAllowed(collapseClause) || parser.parseLParen() || 696 parser.parseAttribute(attr, type) || parser.parseRParen()) 697 return failure(); 698 result.addAttribute("collapse_val", attr); 699 } else if (clauseKeyword == "ordered") { 700 mlir::IntegerAttr attr; 701 if (checkAllowed(orderedClause)) 702 return failure(); 703 if (succeeded(parser.parseOptionalLParen())) { 704 auto type = parser.getBuilder().getI64Type(); 705 if (parser.parseAttribute(attr, type) || parser.parseRParen()) 706 return failure(); 707 } else { 708 // Use 0 to represent no ordered parameter was specified 709 attr = parser.getBuilder().getI64IntegerAttr(0); 710 } 711 result.addAttribute("ordered_val", attr); 712 } else if (clauseKeyword == "order") { 713 StringRef order; 714 if (checkAllowed(orderClause) || parser.parseLParen() || 715 parser.parseKeyword(&order) || parser.parseRParen()) 716 return failure(); 717 auto attr = parser.getBuilder().getStringAttr(order); 718 result.addAttribute("order_val", attr); 719 } else if (clauseKeyword == "memory_order") { 720 StringRef memoryOrder; 721 if (checkAllowed(memoryOrderClause) || parser.parseLParen() || 722 parser.parseKeyword(&memoryOrder) || parser.parseRParen()) 723 return failure(); 724 result.addAttribute("memory_order", 725 parser.getBuilder().getStringAttr(memoryOrder)); 726 } else if (clauseKeyword == "hint") { 727 IntegerAttr hint; 728 if (checkAllowed(hintClause) || 729 parseSynchronizationHint(parser, hint, false)) 730 return failure(); 731 result.addAttribute("hint", hint); 732 } else { 733 return parser.emitError(parser.getNameLoc()) 734 << clauseKeyword << " is not a valid clause"; 735 } 736 } 737 738 // Add if parameter. 739 if (done[ifClause] && clauseSegments[pos[ifClause]] && 740 failed( 741 parser.resolveOperand(ifCond.first, ifCond.second, result.operands))) 742 return failure(); 743 744 // Add num_threads parameter. 745 if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] && 746 failed(parser.resolveOperand(numThreads.first, numThreads.second, 747 result.operands))) 748 return failure(); 749 750 // Add private parameters. 751 if (done[privateClause] && clauseSegments[pos[privateClause]] && 752 failed(parser.resolveOperands(privates, privateTypes, 753 privates[0].location, result.operands))) 754 return failure(); 755 756 // Add firstprivate parameters. 757 if (done[firstprivateClause] && clauseSegments[pos[firstprivateClause]] && 758 failed(parser.resolveOperands(firstprivates, firstprivateTypes, 759 firstprivates[0].location, 760 result.operands))) 761 return failure(); 762 763 // Add lastprivate parameters. 764 if (done[lastprivateClause] && clauseSegments[pos[lastprivateClause]] && 765 failed(parser.resolveOperands(lastprivates, lastprivateTypes, 766 lastprivates[0].location, result.operands))) 767 return failure(); 768 769 // Add shared parameters. 770 if (done[sharedClause] && clauseSegments[pos[sharedClause]] && 771 failed(parser.resolveOperands(shareds, sharedTypes, shareds[0].location, 772 result.operands))) 773 return failure(); 774 775 // Add copyin parameters. 776 if (done[copyinClause] && clauseSegments[pos[copyinClause]] && 777 failed(parser.resolveOperands(copyins, copyinTypes, copyins[0].location, 778 result.operands))) 779 return failure(); 780 781 // Add allocate parameters. 782 if (done[allocateClause] && clauseSegments[pos[allocateClause]] && 783 failed(parser.resolveOperands(allocates, allocateTypes, 784 allocates[0].location, result.operands))) 785 return failure(); 786 787 // Add allocator parameters. 788 if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] && 789 failed(parser.resolveOperands(allocators, allocatorTypes, 790 allocators[0].location, result.operands))) 791 return failure(); 792 793 // Add reduction parameters and symbols 794 if (done[reductionClause] && clauseSegments[pos[reductionClause]]) { 795 if (failed(parser.resolveOperands(reductionVars, reductionVarTypes, 796 parser.getNameLoc(), result.operands))) 797 return failure(); 798 799 SmallVector<Attribute> reductions(reductionSymbols.begin(), 800 reductionSymbols.end()); 801 result.addAttribute("reductions", 802 parser.getBuilder().getArrayAttr(reductions)); 803 } 804 805 // Add linear parameters 806 if (done[linearClause] && clauseSegments[pos[linearClause]]) { 807 auto linearStepType = parser.getBuilder().getI32Type(); 808 SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType); 809 if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location, 810 result.operands)) || 811 failed(parser.resolveOperands(linearSteps, linearStepTypes, 812 linearSteps[0].location, 813 result.operands))) 814 return failure(); 815 } 816 817 // Add schedule parameters 818 if (done[scheduleClause] && !schedule.empty()) { 819 schedule[0] = llvm::toUpper(schedule[0]); 820 auto attr = parser.getBuilder().getStringAttr(schedule); 821 result.addAttribute("schedule_val", attr); 822 if (modifiers.size() > 0) { 823 auto mod = parser.getBuilder().getStringAttr(modifiers[0]); 824 result.addAttribute("schedule_modifier", mod); 825 } 826 if (scheduleChunkSize) { 827 auto chunkSizeType = parser.getBuilder().getI32Type(); 828 parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands); 829 } 830 } 831 832 segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end()); 833 834 return success(); 835 } 836 837 /// Parses a parallel operation. 838 /// 839 /// operation ::= `omp.parallel` clause-list 840 /// clause-list ::= clause | clause clause-list 841 /// clause ::= if | num-threads | private | firstprivate | shared | copyin | 842 /// allocate | default | proc-bind 843 /// 844 static ParseResult parseParallelOp(OpAsmParser &parser, 845 OperationState &result) { 846 SmallVector<ClauseType> clauses = { 847 ifClause, numThreadsClause, privateClause, 848 firstprivateClause, sharedClause, copyinClause, 849 allocateClause, defaultClause, procBindClause}; 850 851 SmallVector<int> segments; 852 853 if (failed(parseClauses(parser, result, clauses, segments))) 854 return failure(); 855 856 result.addAttribute("operand_segment_sizes", 857 parser.getBuilder().getI32VectorAttr(segments)); 858 859 Region *body = result.addRegion(); 860 SmallVector<OpAsmParser::OperandType> regionArgs; 861 SmallVector<Type> regionArgTypes; 862 if (parser.parseRegion(*body, regionArgs, regionArgTypes)) 863 return failure(); 864 return success(); 865 } 866 867 /// Parses an OpenMP Workshare Loop operation 868 /// 869 /// wsloop ::= `omp.wsloop` loop-control clause-list 870 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds 871 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps 872 /// steps := `step` `(`ssa-id-list`)` 873 /// clause-list ::= clause clause-list | empty 874 /// clause ::= private | firstprivate | lastprivate | linear | schedule | 875 // collapse | nowait | ordered | order | reduction 876 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) { 877 878 // Parse an opening `(` followed by induction variables followed by `)` 879 SmallVector<OpAsmParser::OperandType> ivs; 880 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, 881 OpAsmParser::Delimiter::Paren)) 882 return failure(); 883 884 int numIVs = static_cast<int>(ivs.size()); 885 Type loopVarType; 886 if (parser.parseColonType(loopVarType)) 887 return failure(); 888 889 // Parse loop bounds. 890 SmallVector<OpAsmParser::OperandType> lower; 891 if (parser.parseEqual() || 892 parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) || 893 parser.resolveOperands(lower, loopVarType, result.operands)) 894 return failure(); 895 896 SmallVector<OpAsmParser::OperandType> upper; 897 if (parser.parseKeyword("to") || 898 parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) || 899 parser.resolveOperands(upper, loopVarType, result.operands)) 900 return failure(); 901 902 if (succeeded(parser.parseOptionalKeyword("inclusive"))) { 903 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 904 result.addAttribute("inclusive", attr); 905 } 906 907 // Parse step values. 908 SmallVector<OpAsmParser::OperandType> steps; 909 if (parser.parseKeyword("step") || 910 parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) || 911 parser.resolveOperands(steps, loopVarType, result.operands)) 912 return failure(); 913 914 SmallVector<ClauseType> clauses = { 915 privateClause, firstprivateClause, lastprivateClause, linearClause, 916 reductionClause, collapseClause, orderClause, orderedClause, 917 nowaitClause, scheduleClause}; 918 SmallVector<int> segments{numIVs, numIVs, numIVs}; 919 if (failed(parseClauses(parser, result, clauses, segments))) 920 return failure(); 921 922 result.addAttribute("operand_segment_sizes", 923 parser.getBuilder().getI32VectorAttr(segments)); 924 925 // Now parse the body. 926 Region *body = result.addRegion(); 927 SmallVector<Type> ivTypes(numIVs, loopVarType); 928 SmallVector<OpAsmParser::OperandType> blockArgs(ivs); 929 if (parser.parseRegion(*body, blockArgs, ivTypes)) 930 return failure(); 931 return success(); 932 } 933 934 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) { 935 auto args = op.getRegion().front().getArguments(); 936 p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound() 937 << ") to (" << op.upperBound() << ") "; 938 if (op.inclusive()) { 939 p << "inclusive "; 940 } 941 p << "step (" << op.step() << ") "; 942 943 printDataVars(p, op.private_vars(), "private"); 944 printDataVars(p, op.firstprivate_vars(), "firstprivate"); 945 printDataVars(p, op.lastprivate_vars(), "lastprivate"); 946 947 if (op.linear_vars().size()) { 948 p << "linear"; 949 printLinearClause(p, op.linear_vars(), op.linear_step_vars()); 950 } 951 952 if (auto sched = op.schedule_val()) { 953 p << "schedule"; 954 printScheduleClause(p, sched.getValue(), op.schedule_modifier(), 955 op.schedule_chunk_var()); 956 } 957 958 if (auto collapse = op.collapse_val()) 959 p << "collapse(" << collapse << ") "; 960 961 if (op.nowait()) 962 p << "nowait "; 963 964 if (auto ordered = op.ordered_val()) 965 p << "ordered(" << ordered << ") "; 966 967 if (auto order = op.order_val()) 968 p << "order(" << order << ") "; 969 970 if (!op.reduction_vars().empty()) { 971 p << "reduction("; 972 printReductionVarList(p, op.reductions(), op.reduction_vars()); 973 } 974 975 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 976 } 977 978 //===----------------------------------------------------------------------===// 979 // ReductionOp 980 //===----------------------------------------------------------------------===// 981 982 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, 983 Region ®ion) { 984 if (parser.parseOptionalKeyword("atomic")) 985 return success(); 986 return parser.parseRegion(region); 987 } 988 989 static void printAtomicReductionRegion(OpAsmPrinter &printer, 990 ReductionDeclareOp op, Region ®ion) { 991 if (region.empty()) 992 return; 993 printer << "atomic "; 994 printer.printRegion(region); 995 } 996 997 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) { 998 if (op.initializerRegion().empty()) 999 return op.emitOpError() << "expects non-empty initializer region"; 1000 Block &initializerEntryBlock = op.initializerRegion().front(); 1001 if (initializerEntryBlock.getNumArguments() != 1 || 1002 initializerEntryBlock.getArgument(0).getType() != op.type()) { 1003 return op.emitOpError() << "expects initializer region with one argument " 1004 "of the reduction type"; 1005 } 1006 1007 for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) { 1008 if (yieldOp.results().size() != 1 || 1009 yieldOp.results().getTypes()[0] != op.type()) 1010 return op.emitOpError() << "expects initializer region to yield a value " 1011 "of the reduction type"; 1012 } 1013 1014 if (op.reductionRegion().empty()) 1015 return op.emitOpError() << "expects non-empty reduction region"; 1016 Block &reductionEntryBlock = op.reductionRegion().front(); 1017 if (reductionEntryBlock.getNumArguments() != 2 || 1018 reductionEntryBlock.getArgumentTypes()[0] != 1019 reductionEntryBlock.getArgumentTypes()[1] || 1020 reductionEntryBlock.getArgumentTypes()[0] != op.type()) 1021 return op.emitOpError() << "expects reduction region with two arguments of " 1022 "the reduction type"; 1023 for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) { 1024 if (yieldOp.results().size() != 1 || 1025 yieldOp.results().getTypes()[0] != op.type()) 1026 return op.emitOpError() << "expects reduction region to yield a value " 1027 "of the reduction type"; 1028 } 1029 1030 if (op.atomicReductionRegion().empty()) 1031 return success(); 1032 1033 Block &atomicReductionEntryBlock = op.atomicReductionRegion().front(); 1034 if (atomicReductionEntryBlock.getNumArguments() != 2 || 1035 atomicReductionEntryBlock.getArgumentTypes()[0] != 1036 atomicReductionEntryBlock.getArgumentTypes()[1]) 1037 return op.emitOpError() << "expects atomic reduction region with two " 1038 "arguments of the same type"; 1039 auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0] 1040 .dyn_cast<PointerLikeType>(); 1041 if (!ptrType || ptrType.getElementType() != op.type()) 1042 return op.emitOpError() << "expects atomic reduction region arguments to " 1043 "be accumulators containing the reduction type"; 1044 return success(); 1045 } 1046 1047 static LogicalResult verifyReductionOp(ReductionOp op) { 1048 // TODO: generalize this to an op interface when there is more than one op 1049 // that supports reductions. 1050 auto container = op->getParentOfType<WsLoopOp>(); 1051 for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) 1052 if (container.reduction_vars()[i] == op.accumulator()) 1053 return success(); 1054 1055 return op.emitOpError() << "the accumulator is not used by the parent"; 1056 } 1057 1058 //===----------------------------------------------------------------------===// 1059 // WsLoopOp 1060 //===----------------------------------------------------------------------===// 1061 1062 void WsLoopOp::build(OpBuilder &builder, OperationState &state, 1063 ValueRange lowerBound, ValueRange upperBound, 1064 ValueRange step, ArrayRef<NamedAttribute> attributes) { 1065 build(builder, state, TypeRange(), lowerBound, upperBound, step, 1066 /*private_vars=*/ValueRange(), 1067 /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(), 1068 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), 1069 /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr, 1070 /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr, 1071 /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr, 1072 /*inclusive=*/nullptr, /*buildBody=*/false); 1073 state.addAttributes(attributes); 1074 } 1075 1076 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes, 1077 ValueRange operands, ArrayRef<NamedAttribute> attributes) { 1078 state.addOperands(operands); 1079 state.addAttributes(attributes); 1080 (void)state.addRegion(); 1081 assert(resultTypes.empty() && "mismatched number of return types"); 1082 state.addTypes(resultTypes); 1083 } 1084 1085 void WsLoopOp::build(OpBuilder &builder, OperationState &result, 1086 TypeRange typeRange, ValueRange lowerBounds, 1087 ValueRange upperBounds, ValueRange steps, 1088 ValueRange privateVars, ValueRange firstprivateVars, 1089 ValueRange lastprivateVars, ValueRange linearVars, 1090 ValueRange linearStepVars, ValueRange reductionVars, 1091 StringAttr scheduleVal, Value scheduleChunkVar, 1092 IntegerAttr collapseVal, UnitAttr nowait, 1093 IntegerAttr orderedVal, StringAttr orderVal, 1094 UnitAttr inclusive, bool buildBody) { 1095 result.addOperands(lowerBounds); 1096 result.addOperands(upperBounds); 1097 result.addOperands(steps); 1098 result.addOperands(privateVars); 1099 result.addOperands(firstprivateVars); 1100 result.addOperands(linearVars); 1101 result.addOperands(linearStepVars); 1102 if (scheduleChunkVar) 1103 result.addOperands(scheduleChunkVar); 1104 1105 if (scheduleVal) 1106 result.addAttribute("schedule_val", scheduleVal); 1107 if (collapseVal) 1108 result.addAttribute("collapse_val", collapseVal); 1109 if (nowait) 1110 result.addAttribute("nowait", nowait); 1111 if (orderedVal) 1112 result.addAttribute("ordered_val", orderedVal); 1113 if (orderVal) 1114 result.addAttribute("order", orderVal); 1115 if (inclusive) 1116 result.addAttribute("inclusive", inclusive); 1117 result.addAttribute( 1118 WsLoopOp::getOperandSegmentSizeAttr(), 1119 builder.getI32VectorAttr( 1120 {static_cast<int32_t>(lowerBounds.size()), 1121 static_cast<int32_t>(upperBounds.size()), 1122 static_cast<int32_t>(steps.size()), 1123 static_cast<int32_t>(privateVars.size()), 1124 static_cast<int32_t>(firstprivateVars.size()), 1125 static_cast<int32_t>(lastprivateVars.size()), 1126 static_cast<int32_t>(linearVars.size()), 1127 static_cast<int32_t>(linearStepVars.size()), 1128 static_cast<int32_t>(reductionVars.size()), 1129 static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)})); 1130 1131 Region *bodyRegion = result.addRegion(); 1132 if (buildBody) { 1133 OpBuilder::InsertionGuard guard(builder); 1134 unsigned numIVs = steps.size(); 1135 SmallVector<Type, 8> argTypes(numIVs, steps.getType().front()); 1136 builder.createBlock(bodyRegion, {}, argTypes); 1137 } 1138 } 1139 1140 static LogicalResult verifyWsLoopOp(WsLoopOp op) { 1141 return verifyReductionVarList(op, op.reductions(), op.reduction_vars()); 1142 } 1143 1144 //===----------------------------------------------------------------------===// 1145 // Verifier for critical construct (2.17.1) 1146 //===----------------------------------------------------------------------===// 1147 1148 static LogicalResult verifyCriticalDeclareOp(CriticalDeclareOp op) { 1149 return verifySynchronizationHint(op, op.hint()); 1150 } 1151 1152 static LogicalResult verifyCriticalOp(CriticalOp op) { 1153 1154 if (op.nameAttr()) { 1155 auto symbolRef = op.nameAttr().cast<SymbolRefAttr>(); 1156 auto decl = 1157 SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef); 1158 if (!decl) { 1159 return op.emitOpError() << "expected symbol reference " << symbolRef 1160 << " to point to a critical declaration"; 1161 } 1162 } 1163 1164 return success(); 1165 } 1166 1167 //===----------------------------------------------------------------------===// 1168 // Verifier for ordered construct 1169 //===----------------------------------------------------------------------===// 1170 1171 static LogicalResult verifyOrderedOp(OrderedOp op) { 1172 auto container = op->getParentOfType<WsLoopOp>(); 1173 if (!container || !container.ordered_valAttr() || 1174 container.ordered_valAttr().getInt() == 0) 1175 return op.emitOpError() << "ordered depend directive must be closely " 1176 << "nested inside a worksharing-loop with ordered " 1177 << "clause with parameter present"; 1178 1179 if (container.ordered_valAttr().getInt() != 1180 (int64_t)op.num_loops_val().getValue()) 1181 return op.emitOpError() << "number of variables in depend clause does not " 1182 << "match number of iteration variables in the " 1183 << "doacross loop"; 1184 1185 return success(); 1186 } 1187 1188 static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) { 1189 // TODO: The code generation for ordered simd directive is not supported yet. 1190 if (op.simd()) 1191 return failure(); 1192 1193 if (auto container = op->getParentOfType<WsLoopOp>()) { 1194 if (!container.ordered_valAttr() || 1195 container.ordered_valAttr().getInt() != 0) 1196 return op.emitOpError() << "ordered region must be closely nested inside " 1197 << "a worksharing-loop region with an ordered " 1198 << "clause without parameter present"; 1199 } 1200 1201 return success(); 1202 } 1203 1204 //===----------------------------------------------------------------------===// 1205 // AtomicReadOp 1206 //===----------------------------------------------------------------------===// 1207 1208 /// Parser for AtomicReadOp 1209 /// 1210 /// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type 1211 /// address ::= operand `:` type 1212 static ParseResult parseAtomicReadOp(OpAsmParser &parser, 1213 OperationState &result) { 1214 OpAsmParser::OperandType address; 1215 Type addressType; 1216 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1217 SmallVector<int> segments; 1218 1219 if (parser.parseOperand(address) || 1220 parseClauses(parser, result, clauses, segments) || 1221 parser.parseColonType(addressType) || 1222 parser.resolveOperand(address, addressType, result.operands)) 1223 return failure(); 1224 1225 SmallVector<Type> resultType; 1226 if (parser.parseArrowTypeList(resultType)) 1227 return failure(); 1228 result.addTypes(resultType); 1229 return success(); 1230 } 1231 1232 /// Printer for AtomicReadOp 1233 static void printAtomicReadOp(OpAsmPrinter &p, AtomicReadOp op) { 1234 p << " " << op.address() << " "; 1235 if (op.memory_order()) 1236 p << "memory_order(" << op.memory_order().getValue() << ") "; 1237 if (op.hintAttr()) 1238 printSynchronizationHint(p << " ", op, op.hintAttr()); 1239 p << ": " << op.address().getType() << " -> " << op.getType(); 1240 return; 1241 } 1242 1243 /// Verifier for AtomicReadOp 1244 static LogicalResult verifyAtomicReadOp(AtomicReadOp op) { 1245 if (op.memory_order()) { 1246 StringRef memOrder = op.memory_order().getValue(); 1247 if (memOrder.equals("acq_rel") || memOrder.equals("release")) 1248 return op.emitError( 1249 "memory-order must not be acq_rel or release for atomic reads"); 1250 } 1251 return verifySynchronizationHint(op, op.hint()); 1252 } 1253 1254 //===----------------------------------------------------------------------===// 1255 // AtomicWriteOp 1256 //===----------------------------------------------------------------------===// 1257 1258 /// Parser for AtomicWriteOp 1259 /// 1260 /// operation ::= `omp.atomic.write` atomic-clause-list operands 1261 /// operands ::= address `,` value 1262 /// address ::= operand `:` type 1263 /// value ::= operand `:` type 1264 static ParseResult parseAtomicWriteOp(OpAsmParser &parser, 1265 OperationState &result) { 1266 OpAsmParser::OperandType address, value; 1267 Type addrType, valueType; 1268 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1269 SmallVector<int> segments; 1270 1271 if (parser.parseOperand(address) || parser.parseComma() || 1272 parser.parseOperand(value) || 1273 parseClauses(parser, result, clauses, segments) || 1274 parser.parseColonType(addrType) || parser.parseComma() || 1275 parser.parseType(valueType) || 1276 parser.resolveOperand(address, addrType, result.operands) || 1277 parser.resolveOperand(value, valueType, result.operands)) 1278 return failure(); 1279 return success(); 1280 } 1281 1282 /// Printer for AtomicWriteOp 1283 static void printAtomicWriteOp(OpAsmPrinter &p, AtomicWriteOp op) { 1284 p << " " << op.address() << ", " << op.value() << " "; 1285 if (op.memory_order()) 1286 p << "memory_order(" << op.memory_order() << ") "; 1287 if (op.hintAttr()) 1288 printSynchronizationHint(p, op, op.hintAttr()); 1289 p << ": " << op.address().getType() << ", " << op.value().getType(); 1290 return; 1291 } 1292 1293 /// Verifier for AtomicWriteOp 1294 static LogicalResult verifyAtomicWriteOp(AtomicWriteOp op) { 1295 if (op.memory_order()) { 1296 StringRef memoryOrder = op.memory_order().getValue(); 1297 if (memoryOrder.equals("acq_rel") || memoryOrder.equals("acquire")) 1298 return op.emitError( 1299 "memory-order must not be acq_rel or acquire for atomic writes"); 1300 } 1301 return verifySynchronizationHint(op, op.hint()); 1302 } 1303 1304 #define GET_OP_CLASSES 1305 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 1306