1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the OpenMP dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/IR/Attributes.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/IR/OperationSupport.h" 19 20 #include "llvm/ADT/BitVector.h" 21 #include "llvm/ADT/SmallString.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/ADT/StringRef.h" 24 #include "llvm/ADT/StringSwitch.h" 25 #include <cstddef> 26 27 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc" 28 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" 29 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc" 30 31 using namespace mlir; 32 using namespace mlir::omp; 33 34 namespace { 35 /// Model for pointer-like types that already provide a `getElementType` method. 36 template <typename T> 37 struct PointerLikeModel 38 : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> { 39 Type getElementType(Type pointer) const { 40 return pointer.cast<T>().getElementType(); 41 } 42 }; 43 } // end namespace 44 45 void OpenMPDialect::initialize() { 46 addOperations< 47 #define GET_OP_LIST 48 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 49 >(); 50 51 LLVM::LLVMPointerType::attachInterface< 52 PointerLikeModel<LLVM::LLVMPointerType>>(*getContext()); 53 MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext()); 54 } 55 56 //===----------------------------------------------------------------------===// 57 // ParallelOp 58 //===----------------------------------------------------------------------===// 59 60 void ParallelOp::build(OpBuilder &builder, OperationState &state, 61 ArrayRef<NamedAttribute> attributes) { 62 ParallelOp::build( 63 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, 64 /*default_val=*/nullptr, /*private_vars=*/ValueRange(), 65 /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(), 66 /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(), 67 /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr); 68 state.addAttributes(attributes); 69 } 70 71 //===----------------------------------------------------------------------===// 72 // Parser and printer for Operand and type list 73 //===----------------------------------------------------------------------===// 74 75 /// Parse a list of operands with types. 76 /// 77 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)` 78 /// ssa-id-and-type-list ::= ssa-id-and-type | 79 /// ssa-id-and-type `,` ssa-id-and-type-list 80 /// ssa-id-and-type ::= ssa-id `:` type 81 static ParseResult 82 parseOperandAndTypeList(OpAsmParser &parser, 83 SmallVectorImpl<OpAsmParser::OperandType> &operands, 84 SmallVectorImpl<Type> &types) { 85 return parser.parseCommaSeparatedList( 86 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 87 OpAsmParser::OperandType operand; 88 Type type; 89 if (parser.parseOperand(operand) || parser.parseColonType(type)) 90 return failure(); 91 operands.push_back(operand); 92 types.push_back(type); 93 return success(); 94 }); 95 } 96 97 /// Print an operand and type list with parentheses 98 static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) { 99 p << "("; 100 llvm::interleaveComma( 101 operands, p, [&](const Value &v) { p << v << " : " << v.getType(); }); 102 p << ") "; 103 } 104 105 /// Print data variables corresponding to a data-sharing clause `name` 106 static void printDataVars(OpAsmPrinter &p, OperandRange operands, 107 StringRef name) { 108 if (operands.size()) { 109 p << name; 110 printOperandAndTypeList(p, operands); 111 } 112 } 113 114 //===----------------------------------------------------------------------===// 115 // Parser and printer for Allocate Clause 116 //===----------------------------------------------------------------------===// 117 118 /// Parse an allocate clause with allocators and a list of operands with types. 119 /// 120 /// allocate ::= `allocate` `(` allocate-operand-list `)` 121 /// allocate-operand-list :: = allocate-operand | 122 /// allocator-operand `,` allocate-operand-list 123 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 124 /// ssa-id-and-type ::= ssa-id `:` type 125 static ParseResult parseAllocateAndAllocator( 126 OpAsmParser &parser, 127 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate, 128 SmallVectorImpl<Type> &typesAllocate, 129 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator, 130 SmallVectorImpl<Type> &typesAllocator) { 131 132 return parser.parseCommaSeparatedList( 133 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 134 OpAsmParser::OperandType operand; 135 Type type; 136 if (parser.parseOperand(operand) || parser.parseColonType(type)) 137 return failure(); 138 operandsAllocator.push_back(operand); 139 typesAllocator.push_back(type); 140 if (parser.parseArrow()) 141 return failure(); 142 if (parser.parseOperand(operand) || parser.parseColonType(type)) 143 return failure(); 144 145 operandsAllocate.push_back(operand); 146 typesAllocate.push_back(type); 147 return success(); 148 }); 149 } 150 151 /// Print allocate clause 152 static void printAllocateAndAllocator(OpAsmPrinter &p, 153 OperandRange varsAllocate, 154 OperandRange varsAllocator) { 155 p << "allocate("; 156 for (unsigned i = 0; i < varsAllocate.size(); ++i) { 157 std::string separator = i == varsAllocate.size() - 1 ? ") " : ", "; 158 p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> "; 159 p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator; 160 } 161 } 162 163 static LogicalResult verifyParallelOp(ParallelOp op) { 164 if (op.allocate_vars().size() != op.allocators_vars().size()) 165 return op.emitError( 166 "expected equal sizes for allocate and allocator variables"); 167 return success(); 168 } 169 170 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) { 171 p << " "; 172 if (auto ifCond = op.if_expr_var()) 173 p << "if(" << ifCond << " : " << ifCond.getType() << ") "; 174 175 if (auto threads = op.num_threads_var()) 176 p << "num_threads(" << threads << " : " << threads.getType() << ") "; 177 178 printDataVars(p, op.private_vars(), "private"); 179 printDataVars(p, op.firstprivate_vars(), "firstprivate"); 180 printDataVars(p, op.shared_vars(), "shared"); 181 printDataVars(p, op.copyin_vars(), "copyin"); 182 183 if (!op.allocate_vars().empty()) 184 printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars()); 185 186 if (auto def = op.default_val()) 187 p << "default(" << def->drop_front(3) << ") "; 188 189 if (auto bind = op.proc_bind_val()) 190 p << "proc_bind(" << bind << ") "; 191 192 p.printRegion(op.getRegion()); 193 } 194 195 //===----------------------------------------------------------------------===// 196 // Parser and printer for Linear Clause 197 //===----------------------------------------------------------------------===// 198 199 /// linear ::= `linear` `(` linear-list `)` 200 /// linear-list := linear-val | linear-val linear-list 201 /// linear-val := ssa-id-and-type `=` ssa-id-and-type 202 static ParseResult 203 parseLinearClause(OpAsmParser &parser, 204 SmallVectorImpl<OpAsmParser::OperandType> &vars, 205 SmallVectorImpl<Type> &types, 206 SmallVectorImpl<OpAsmParser::OperandType> &stepVars) { 207 if (parser.parseLParen()) 208 return failure(); 209 210 do { 211 OpAsmParser::OperandType var; 212 Type type; 213 OpAsmParser::OperandType stepVar; 214 if (parser.parseOperand(var) || parser.parseEqual() || 215 parser.parseOperand(stepVar) || parser.parseColonType(type)) 216 return failure(); 217 218 vars.push_back(var); 219 types.push_back(type); 220 stepVars.push_back(stepVar); 221 } while (succeeded(parser.parseOptionalComma())); 222 223 if (parser.parseRParen()) 224 return failure(); 225 226 return success(); 227 } 228 229 /// Print Linear Clause 230 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars, 231 OperandRange linearStepVars) { 232 size_t linearVarsSize = linearVars.size(); 233 p << "linear("; 234 for (unsigned i = 0; i < linearVarsSize; ++i) { 235 std::string separator = i == linearVarsSize - 1 ? ") " : ", "; 236 p << linearVars[i]; 237 if (linearStepVars.size() > i) 238 p << " = " << linearStepVars[i]; 239 p << " : " << linearVars[i].getType() << separator; 240 } 241 } 242 243 //===----------------------------------------------------------------------===// 244 // Parser and printer for Schedule Clause 245 //===----------------------------------------------------------------------===// 246 247 /// schedule ::= `schedule` `(` sched-list `)` 248 /// sched-list ::= sched-val | sched-val sched-list 249 /// sched-val ::= sched-with-chunk | sched-wo-chunk 250 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 251 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided` 252 /// sched-wo-chunk ::= `auto` | `runtime` 253 static ParseResult 254 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, 255 SmallVectorImpl<SmallString<12>> &modifiers, 256 Optional<OpAsmParser::OperandType> &chunkSize) { 257 if (parser.parseLParen()) 258 return failure(); 259 260 StringRef keyword; 261 if (parser.parseKeyword(&keyword)) 262 return failure(); 263 264 schedule = keyword; 265 if (keyword == "static" || keyword == "dynamic" || keyword == "guided") { 266 if (succeeded(parser.parseOptionalEqual())) { 267 chunkSize = OpAsmParser::OperandType{}; 268 if (parser.parseOperand(*chunkSize)) 269 return failure(); 270 } else { 271 chunkSize = llvm::NoneType::None; 272 } 273 } else if (keyword == "auto" || keyword == "runtime") { 274 chunkSize = llvm::NoneType::None; 275 } else { 276 return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; 277 } 278 279 // If there is a comma, we have one or more modifiers.. 280 if (succeeded(parser.parseOptionalComma())) { 281 StringRef mod; 282 if (parser.parseKeyword(&mod)) 283 return failure(); 284 modifiers.push_back(mod); 285 } 286 287 if (parser.parseRParen()) 288 return failure(); 289 290 return success(); 291 } 292 293 /// Print schedule clause 294 static void printScheduleClause(OpAsmPrinter &p, StringRef &sched, 295 llvm::Optional<StringRef> modifier, 296 Value scheduleChunkVar) { 297 std::string schedLower = sched.lower(); 298 p << "schedule(" << schedLower; 299 if (scheduleChunkVar) 300 p << " = " << scheduleChunkVar; 301 if (modifier && modifier.getValue() != "none") 302 p << ", " << modifier; 303 p << ") "; 304 } 305 306 //===----------------------------------------------------------------------===// 307 // Parser, printer and verifier for ReductionVarList 308 //===----------------------------------------------------------------------===// 309 310 /// reduction ::= `reduction` `(` reduction-entry-list `)` 311 /// reduction-entry-list ::= reduction-entry 312 /// | reduction-entry-list `,` reduction-entry 313 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type 314 static ParseResult 315 parseReductionVarList(OpAsmParser &parser, 316 SmallVectorImpl<SymbolRefAttr> &symbols, 317 SmallVectorImpl<OpAsmParser::OperandType> &operands, 318 SmallVectorImpl<Type> &types) { 319 if (failed(parser.parseLParen())) 320 return failure(); 321 322 do { 323 if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() || 324 parser.parseOperand(operands.emplace_back()) || 325 parser.parseColonType(types.emplace_back())) 326 return failure(); 327 } while (succeeded(parser.parseOptionalComma())); 328 return parser.parseRParen(); 329 } 330 331 /// Print Reduction clause 332 static void printReductionVarList(OpAsmPrinter &p, 333 Optional<ArrayAttr> reductions, 334 OperandRange reduction_vars) { 335 p << "reduction("; 336 for (unsigned i = 0, e = reductions->size(); i < e; ++i) { 337 if (i != 0) 338 p << ", "; 339 p << (*reductions)[i] << " -> " << reduction_vars[i] << " : " 340 << reduction_vars[i].getType(); 341 } 342 p << ") "; 343 } 344 345 /// Verifies Reduction Clause 346 static LogicalResult verifyReductionVarList(Operation *op, 347 Optional<ArrayAttr> reductions, 348 OperandRange reduction_vars) { 349 if (reduction_vars.size() != 0) { 350 if (!reductions || reductions->size() != reduction_vars.size()) 351 return op->emitOpError() 352 << "expected as many reduction symbol references " 353 "as reduction variables"; 354 } else { 355 if (reductions) 356 return op->emitOpError() << "unexpected reduction symbol references"; 357 return success(); 358 } 359 360 DenseSet<Value> accumulators; 361 for (auto args : llvm::zip(reduction_vars, *reductions)) { 362 Value accum = std::get<0>(args); 363 364 if (!accumulators.insert(accum).second) 365 return op->emitOpError() << "accumulator variable used more than once"; 366 367 Type varType = accum.getType().cast<PointerLikeType>(); 368 auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>(); 369 auto decl = 370 SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef); 371 if (!decl) 372 return op->emitOpError() << "expected symbol reference " << symbolRef 373 << " to point to a reduction declaration"; 374 375 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) 376 return op->emitOpError() 377 << "expected accumulator (" << varType 378 << ") to be the same type as reduction declaration (" 379 << decl.getAccumulatorType() << ")"; 380 } 381 382 return success(); 383 } 384 385 //===----------------------------------------------------------------------===// 386 // Parser, printer and verifier for Synchronization Hint (2.17.12) 387 //===----------------------------------------------------------------------===// 388 389 /// Parses a Synchronization Hint clause. The value of hint is an integer 390 /// which is a combination of different hints from `omp_sync_hint_t`. 391 /// 392 /// hint-clause = `hint` `(` hint-value `)` 393 static ParseResult parseSynchronizationHint(OpAsmParser &parser, 394 IntegerAttr &hintAttr, 395 bool parseKeyword = true) { 396 if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) { 397 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0); 398 return success(); 399 } 400 401 if (failed(parser.parseLParen())) 402 return failure(); 403 StringRef hintKeyword; 404 int64_t hint = 0; 405 do { 406 if (failed(parser.parseKeyword(&hintKeyword))) 407 return failure(); 408 if (hintKeyword == "uncontended") 409 hint |= 1; 410 else if (hintKeyword == "contended") 411 hint |= 2; 412 else if (hintKeyword == "nonspeculative") 413 hint |= 4; 414 else if (hintKeyword == "speculative") 415 hint |= 8; 416 else 417 return parser.emitError(parser.getCurrentLocation()) 418 << hintKeyword << " is not a valid hint"; 419 } while (succeeded(parser.parseOptionalComma())); 420 if (failed(parser.parseRParen())) 421 return failure(); 422 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint); 423 return success(); 424 } 425 426 /// Prints a Synchronization Hint clause 427 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, 428 IntegerAttr hintAttr) { 429 int64_t hint = hintAttr.getInt(); 430 431 if (hint == 0) 432 return; 433 434 // Helper function to get n-th bit from the right end of `value` 435 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 436 437 bool uncontended = bitn(hint, 0); 438 bool contended = bitn(hint, 1); 439 bool nonspeculative = bitn(hint, 2); 440 bool speculative = bitn(hint, 3); 441 442 SmallVector<StringRef> hints; 443 if (uncontended) 444 hints.push_back("uncontended"); 445 if (contended) 446 hints.push_back("contended"); 447 if (nonspeculative) 448 hints.push_back("nonspeculative"); 449 if (speculative) 450 hints.push_back("speculative"); 451 452 p << "hint("; 453 llvm::interleaveComma(hints, p); 454 p << ") "; 455 } 456 457 /// Verifies a synchronization hint clause 458 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) { 459 460 // Helper function to get n-th bit from the right end of `value` 461 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 462 463 bool uncontended = bitn(hint, 0); 464 bool contended = bitn(hint, 1); 465 bool nonspeculative = bitn(hint, 2); 466 bool speculative = bitn(hint, 3); 467 468 if (uncontended && contended) 469 return op->emitOpError() << "the hints omp_sync_hint_uncontended and " 470 "omp_sync_hint_contended cannot be combined"; 471 if (nonspeculative && speculative) 472 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and " 473 "omp_sync_hint_speculative cannot be combined."; 474 return success(); 475 } 476 477 enum ClauseType { 478 ifClause, 479 numThreadsClause, 480 privateClause, 481 firstprivateClause, 482 lastprivateClause, 483 sharedClause, 484 copyinClause, 485 allocateClause, 486 defaultClause, 487 procBindClause, 488 reductionClause, 489 nowaitClause, 490 linearClause, 491 scheduleClause, 492 collapseClause, 493 orderClause, 494 orderedClause, 495 memoryOrderClause, 496 hintClause, 497 COUNT 498 }; 499 500 //===----------------------------------------------------------------------===// 501 // Parser for Clause List 502 //===----------------------------------------------------------------------===// 503 504 /// Parse a list of clauses. The clauses can appear in any order, but their 505 /// operand segment indices are in the same order that they are passed in the 506 /// `clauses` list. The operand segments are added over the prevSegments 507 508 /// clause-list ::= clause clause-list | empty 509 /// clause ::= if | num-threads | private | firstprivate | lastprivate | 510 /// shared | copyin | allocate | default | proc-bind | reduction | 511 /// nowait | linear | schedule | collapse | order | ordered | 512 /// inclusive 513 /// if ::= `if` `(` ssa-id-and-type `)` 514 /// num-threads ::= `num_threads` `(` ssa-id-and-type `)` 515 /// private ::= `private` operand-and-type-list 516 /// firstprivate ::= `firstprivate` operand-and-type-list 517 /// lastprivate ::= `lastprivate` operand-and-type-list 518 /// shared ::= `shared` operand-and-type-list 519 /// copyin ::= `copyin` operand-and-type-list 520 /// allocate ::= `allocate` `(` allocate-operand-list `)` 521 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) 522 /// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` 523 /// reduction ::= `reduction` `(` reduction-entry-list `)` 524 /// nowait ::= `nowait` 525 /// linear ::= `linear` `(` linear-list `)` 526 /// schedule ::= `schedule` `(` sched-list `)` 527 /// collapse ::= `collapse` `(` ssa-id-and-type `)` 528 /// order ::= `order` `(` `concurrent` `)` 529 /// ordered ::= `ordered` `(` ssa-id-and-type `)` 530 /// inclusive ::= `inclusive` 531 /// 532 /// Note that each clause can only appear once in the clase-list. 533 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result, 534 SmallVectorImpl<ClauseType> &clauses, 535 SmallVectorImpl<int> &segments) { 536 537 // Check done[clause] to see if it has been parsed already 538 llvm::BitVector done(ClauseType::COUNT, false); 539 540 // See pos[clause] to get position of clause in operand segments 541 SmallVector<int> pos(ClauseType::COUNT, -1); 542 543 // Stores the last parsed clause keyword 544 StringRef clauseKeyword; 545 StringRef opName = result.name.getStringRef(); 546 547 // Containers for storing operands, types and attributes for various clauses 548 std::pair<OpAsmParser::OperandType, Type> ifCond; 549 std::pair<OpAsmParser::OperandType, Type> numThreads; 550 551 SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates, 552 shareds, copyins; 553 SmallVector<Type> privateTypes, firstprivateTypes, lastprivateTypes, 554 sharedTypes, copyinTypes; 555 556 SmallVector<OpAsmParser::OperandType> allocates, allocators; 557 SmallVector<Type> allocateTypes, allocatorTypes; 558 559 SmallVector<SymbolRefAttr> reductionSymbols; 560 SmallVector<OpAsmParser::OperandType> reductionVars; 561 SmallVector<Type> reductionVarTypes; 562 563 SmallVector<OpAsmParser::OperandType> linears; 564 SmallVector<Type> linearTypes; 565 SmallVector<OpAsmParser::OperandType> linearSteps; 566 567 SmallString<8> schedule; 568 SmallVector<SmallString<12>> modifiers; 569 Optional<OpAsmParser::OperandType> scheduleChunkSize; 570 571 // Compute the position of clauses in operand segments 572 int currPos = 0; 573 for (ClauseType clause : clauses) { 574 575 // Skip the following clauses - they do not take any position in operand 576 // segments 577 if (clause == defaultClause || clause == procBindClause || 578 clause == nowaitClause || clause == collapseClause || 579 clause == orderClause || clause == orderedClause) 580 continue; 581 582 pos[clause] = currPos++; 583 584 // For the following clauses, two positions are reserved in the operand 585 // segments 586 if (clause == allocateClause || clause == linearClause) 587 currPos++; 588 } 589 590 SmallVector<int> clauseSegments(currPos); 591 592 // Helper function to check if a clause is allowed/repeated or not 593 auto checkAllowed = [&](ClauseType clause, 594 bool allowRepeat = false) -> ParseResult { 595 if (!llvm::is_contained(clauses, clause)) 596 return parser.emitError(parser.getCurrentLocation()) 597 << clauseKeyword << " is not a valid clause for the " << opName 598 << " operation"; 599 if (done[clause] && !allowRepeat) 600 return parser.emitError(parser.getCurrentLocation()) 601 << "at most one " << clauseKeyword << " clause can appear on the " 602 << opName << " operation"; 603 done[clause] = true; 604 return success(); 605 }; 606 607 while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) { 608 if (clauseKeyword == "if") { 609 if (checkAllowed(ifClause) || parser.parseLParen() || 610 parser.parseOperand(ifCond.first) || 611 parser.parseColonType(ifCond.second) || parser.parseRParen()) 612 return failure(); 613 clauseSegments[pos[ifClause]] = 1; 614 } else if (clauseKeyword == "num_threads") { 615 if (checkAllowed(numThreadsClause) || parser.parseLParen() || 616 parser.parseOperand(numThreads.first) || 617 parser.parseColonType(numThreads.second) || parser.parseRParen()) 618 return failure(); 619 clauseSegments[pos[numThreadsClause]] = 1; 620 } else if (clauseKeyword == "private") { 621 if (checkAllowed(privateClause) || 622 parseOperandAndTypeList(parser, privates, privateTypes)) 623 return failure(); 624 clauseSegments[pos[privateClause]] = privates.size(); 625 } else if (clauseKeyword == "firstprivate") { 626 if (checkAllowed(firstprivateClause) || 627 parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) 628 return failure(); 629 clauseSegments[pos[firstprivateClause]] = firstprivates.size(); 630 } else if (clauseKeyword == "lastprivate") { 631 if (checkAllowed(lastprivateClause) || 632 parseOperandAndTypeList(parser, lastprivates, lastprivateTypes)) 633 return failure(); 634 clauseSegments[pos[lastprivateClause]] = lastprivates.size(); 635 } else if (clauseKeyword == "shared") { 636 if (checkAllowed(sharedClause) || 637 parseOperandAndTypeList(parser, shareds, sharedTypes)) 638 return failure(); 639 clauseSegments[pos[sharedClause]] = shareds.size(); 640 } else if (clauseKeyword == "copyin") { 641 if (checkAllowed(copyinClause) || 642 parseOperandAndTypeList(parser, copyins, copyinTypes)) 643 return failure(); 644 clauseSegments[pos[copyinClause]] = copyins.size(); 645 } else if (clauseKeyword == "allocate") { 646 if (checkAllowed(allocateClause) || 647 parseAllocateAndAllocator(parser, allocates, allocateTypes, 648 allocators, allocatorTypes)) 649 return failure(); 650 clauseSegments[pos[allocateClause]] = allocates.size(); 651 clauseSegments[pos[allocateClause] + 1] = allocators.size(); 652 } else if (clauseKeyword == "default") { 653 StringRef defval; 654 if (checkAllowed(defaultClause) || parser.parseLParen() || 655 parser.parseKeyword(&defval) || parser.parseRParen()) 656 return failure(); 657 // The def prefix is required for the attribute as "private" is a keyword 658 // in C++. 659 auto attr = parser.getBuilder().getStringAttr("def" + defval); 660 result.addAttribute("default_val", attr); 661 } else if (clauseKeyword == "proc_bind") { 662 StringRef bind; 663 if (checkAllowed(procBindClause) || parser.parseLParen() || 664 parser.parseKeyword(&bind) || parser.parseRParen()) 665 return failure(); 666 auto attr = parser.getBuilder().getStringAttr(bind); 667 result.addAttribute("proc_bind_val", attr); 668 } else if (clauseKeyword == "reduction") { 669 if (checkAllowed(reductionClause) || 670 parseReductionVarList(parser, reductionSymbols, reductionVars, 671 reductionVarTypes)) 672 return failure(); 673 clauseSegments[pos[reductionClause]] = reductionVars.size(); 674 } else if (clauseKeyword == "nowait") { 675 if (checkAllowed(nowaitClause)) 676 return failure(); 677 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 678 result.addAttribute("nowait", attr); 679 } else if (clauseKeyword == "linear") { 680 if (checkAllowed(linearClause) || 681 parseLinearClause(parser, linears, linearTypes, linearSteps)) 682 return failure(); 683 clauseSegments[pos[linearClause]] = linears.size(); 684 clauseSegments[pos[linearClause] + 1] = linearSteps.size(); 685 } else if (clauseKeyword == "schedule") { 686 if (checkAllowed(scheduleClause) || 687 parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize)) 688 return failure(); 689 if (scheduleChunkSize) { 690 clauseSegments[pos[scheduleClause]] = 1; 691 } 692 } else if (clauseKeyword == "collapse") { 693 auto type = parser.getBuilder().getI64Type(); 694 mlir::IntegerAttr attr; 695 if (checkAllowed(collapseClause) || parser.parseLParen() || 696 parser.parseAttribute(attr, type) || parser.parseRParen()) 697 return failure(); 698 result.addAttribute("collapse_val", attr); 699 } else if (clauseKeyword == "ordered") { 700 mlir::IntegerAttr attr; 701 if (checkAllowed(orderedClause)) 702 return failure(); 703 if (succeeded(parser.parseOptionalLParen())) { 704 auto type = parser.getBuilder().getI64Type(); 705 if (parser.parseAttribute(attr, type) || parser.parseRParen()) 706 return failure(); 707 } else { 708 // Use 0 to represent no ordered parameter was specified 709 attr = parser.getBuilder().getI64IntegerAttr(0); 710 } 711 result.addAttribute("ordered_val", attr); 712 } else if (clauseKeyword == "order") { 713 StringRef order; 714 if (checkAllowed(orderClause) || parser.parseLParen() || 715 parser.parseKeyword(&order) || parser.parseRParen()) 716 return failure(); 717 auto attr = parser.getBuilder().getStringAttr(order); 718 result.addAttribute("order_val", attr); 719 } else if (clauseKeyword == "memory_order") { 720 StringRef memoryOrder; 721 if (checkAllowed(memoryOrderClause) || parser.parseLParen() || 722 parser.parseKeyword(&memoryOrder) || parser.parseRParen()) 723 return failure(); 724 result.addAttribute("memory_order", 725 parser.getBuilder().getStringAttr(memoryOrder)); 726 } else if (clauseKeyword == "hint") { 727 IntegerAttr hint; 728 if (checkAllowed(hintClause) || 729 parseSynchronizationHint(parser, hint, false)) 730 return failure(); 731 result.addAttribute("hint", hint); 732 } else { 733 return parser.emitError(parser.getNameLoc()) 734 << clauseKeyword << " is not a valid clause"; 735 } 736 } 737 738 // Add if parameter. 739 if (done[ifClause] && clauseSegments[pos[ifClause]] && 740 failed( 741 parser.resolveOperand(ifCond.first, ifCond.second, result.operands))) 742 return failure(); 743 744 // Add num_threads parameter. 745 if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] && 746 failed(parser.resolveOperand(numThreads.first, numThreads.second, 747 result.operands))) 748 return failure(); 749 750 // Add private parameters. 751 if (done[privateClause] && clauseSegments[pos[privateClause]] && 752 failed(parser.resolveOperands(privates, privateTypes, 753 privates[0].location, result.operands))) 754 return failure(); 755 756 // Add firstprivate parameters. 757 if (done[firstprivateClause] && clauseSegments[pos[firstprivateClause]] && 758 failed(parser.resolveOperands(firstprivates, firstprivateTypes, 759 firstprivates[0].location, 760 result.operands))) 761 return failure(); 762 763 // Add lastprivate parameters. 764 if (done[lastprivateClause] && clauseSegments[pos[lastprivateClause]] && 765 failed(parser.resolveOperands(lastprivates, lastprivateTypes, 766 lastprivates[0].location, result.operands))) 767 return failure(); 768 769 // Add shared parameters. 770 if (done[sharedClause] && clauseSegments[pos[sharedClause]] && 771 failed(parser.resolveOperands(shareds, sharedTypes, shareds[0].location, 772 result.operands))) 773 return failure(); 774 775 // Add copyin parameters. 776 if (done[copyinClause] && clauseSegments[pos[copyinClause]] && 777 failed(parser.resolveOperands(copyins, copyinTypes, copyins[0].location, 778 result.operands))) 779 return failure(); 780 781 // Add allocate parameters. 782 if (done[allocateClause] && clauseSegments[pos[allocateClause]] && 783 failed(parser.resolveOperands(allocates, allocateTypes, 784 allocates[0].location, result.operands))) 785 return failure(); 786 787 // Add allocator parameters. 788 if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] && 789 failed(parser.resolveOperands(allocators, allocatorTypes, 790 allocators[0].location, result.operands))) 791 return failure(); 792 793 // Add reduction parameters and symbols 794 if (done[reductionClause] && clauseSegments[pos[reductionClause]]) { 795 if (failed(parser.resolveOperands(reductionVars, reductionVarTypes, 796 parser.getNameLoc(), result.operands))) 797 return failure(); 798 799 SmallVector<Attribute> reductions(reductionSymbols.begin(), 800 reductionSymbols.end()); 801 result.addAttribute("reductions", 802 parser.getBuilder().getArrayAttr(reductions)); 803 } 804 805 // Add linear parameters 806 if (done[linearClause] && clauseSegments[pos[linearClause]]) { 807 auto linearStepType = parser.getBuilder().getI32Type(); 808 SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType); 809 if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location, 810 result.operands)) || 811 failed(parser.resolveOperands(linearSteps, linearStepTypes, 812 linearSteps[0].location, 813 result.operands))) 814 return failure(); 815 } 816 817 // Add schedule parameters 818 if (done[scheduleClause] && !schedule.empty()) { 819 schedule[0] = llvm::toUpper(schedule[0]); 820 auto attr = parser.getBuilder().getStringAttr(schedule); 821 result.addAttribute("schedule_val", attr); 822 if (modifiers.size() > 0) { 823 auto mod = parser.getBuilder().getStringAttr(modifiers[0]); 824 result.addAttribute("schedule_modifier", mod); 825 } 826 if (scheduleChunkSize) { 827 auto chunkSizeType = parser.getBuilder().getI32Type(); 828 parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands); 829 } 830 } 831 832 segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end()); 833 834 return success(); 835 } 836 837 /// Parses a parallel operation. 838 /// 839 /// operation ::= `omp.parallel` clause-list 840 /// clause-list ::= clause | clause clause-list 841 /// clause ::= if | num-threads | private | firstprivate | shared | copyin | 842 /// allocate | default | proc-bind 843 /// 844 static ParseResult parseParallelOp(OpAsmParser &parser, 845 OperationState &result) { 846 SmallVector<ClauseType> clauses = { 847 ifClause, numThreadsClause, privateClause, 848 firstprivateClause, sharedClause, copyinClause, 849 allocateClause, defaultClause, procBindClause}; 850 851 SmallVector<int> segments; 852 853 if (failed(parseClauses(parser, result, clauses, segments))) 854 return failure(); 855 856 result.addAttribute("operand_segment_sizes", 857 parser.getBuilder().getI32VectorAttr(segments)); 858 859 Region *body = result.addRegion(); 860 SmallVector<OpAsmParser::OperandType> regionArgs; 861 SmallVector<Type> regionArgTypes; 862 if (parser.parseRegion(*body, regionArgs, regionArgTypes)) 863 return failure(); 864 return success(); 865 } 866 867 //===----------------------------------------------------------------------===// 868 // Parser, printer and verifier for SectionsOp 869 //===----------------------------------------------------------------------===// 870 871 /// Parses an OpenMP Sections operation 872 /// 873 /// sections ::= `omp.sections` clause-list 874 /// clause-list ::= clause clause-list | empty 875 /// clause ::= private | firstprivate | lastprivate | reduction | allocate | 876 /// nowait 877 static ParseResult parseSectionsOp(OpAsmParser &parser, 878 OperationState &result) { 879 880 SmallVector<ClauseType> clauses = {privateClause, firstprivateClause, 881 lastprivateClause, reductionClause, 882 allocateClause, nowaitClause}; 883 884 SmallVector<int> segments; 885 886 if (failed(parseClauses(parser, result, clauses, segments))) 887 return failure(); 888 889 result.addAttribute("operand_segment_sizes", 890 parser.getBuilder().getI32VectorAttr(segments)); 891 892 // Now parse the body. 893 Region *body = result.addRegion(); 894 if (parser.parseRegion(*body)) 895 return failure(); 896 return success(); 897 } 898 899 static void printSectionsOp(OpAsmPrinter &p, SectionsOp op) { 900 p << " "; 901 printDataVars(p, op.private_vars(), "private"); 902 printDataVars(p, op.firstprivate_vars(), "firstprivate"); 903 printDataVars(p, op.lastprivate_vars(), "lastprivate"); 904 905 if (!op.reduction_vars().empty()) 906 printReductionVarList(p, op.reductions(), op.reduction_vars()); 907 908 if (!op.allocate_vars().empty()) 909 printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars()); 910 911 if (op.nowait()) 912 p << "nowait "; 913 914 p.printRegion(op.region()); 915 } 916 917 static LogicalResult verifySectionsOp(SectionsOp op) { 918 919 // A list item may not appear in more than one clause on the same directive, 920 // except that it may be specified in both firstprivate and lastprivate 921 // clauses. 922 for (auto var : op.private_vars()) { 923 if (llvm::is_contained(op.firstprivate_vars(), var)) 924 return op.emitOpError() 925 << "operand used in both private and firstprivate clauses"; 926 if (llvm::is_contained(op.lastprivate_vars(), var)) 927 return op.emitOpError() 928 << "operand used in both private and lastprivate clauses"; 929 } 930 931 if (op.allocate_vars().size() != op.allocators_vars().size()) 932 return op.emitError( 933 "expected equal sizes for allocate and allocator variables"); 934 935 for (auto &inst : *op.region().begin()) { 936 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) 937 op.emitOpError() 938 << "expected omp.section op or terminator op inside region"; 939 } 940 941 return verifyReductionVarList(op, op.reductions(), op.reduction_vars()); 942 } 943 944 /// Parses an OpenMP Workshare Loop operation 945 /// 946 /// wsloop ::= `omp.wsloop` loop-control clause-list 947 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds 948 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps 949 /// steps := `step` `(`ssa-id-list`)` 950 /// clause-list ::= clause clause-list | empty 951 /// clause ::= private | firstprivate | lastprivate | linear | schedule | 952 // collapse | nowait | ordered | order | reduction 953 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) { 954 955 // Parse an opening `(` followed by induction variables followed by `)` 956 SmallVector<OpAsmParser::OperandType> ivs; 957 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, 958 OpAsmParser::Delimiter::Paren)) 959 return failure(); 960 961 int numIVs = static_cast<int>(ivs.size()); 962 Type loopVarType; 963 if (parser.parseColonType(loopVarType)) 964 return failure(); 965 966 // Parse loop bounds. 967 SmallVector<OpAsmParser::OperandType> lower; 968 if (parser.parseEqual() || 969 parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) || 970 parser.resolveOperands(lower, loopVarType, result.operands)) 971 return failure(); 972 973 SmallVector<OpAsmParser::OperandType> upper; 974 if (parser.parseKeyword("to") || 975 parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) || 976 parser.resolveOperands(upper, loopVarType, result.operands)) 977 return failure(); 978 979 if (succeeded(parser.parseOptionalKeyword("inclusive"))) { 980 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 981 result.addAttribute("inclusive", attr); 982 } 983 984 // Parse step values. 985 SmallVector<OpAsmParser::OperandType> steps; 986 if (parser.parseKeyword("step") || 987 parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) || 988 parser.resolveOperands(steps, loopVarType, result.operands)) 989 return failure(); 990 991 SmallVector<ClauseType> clauses = { 992 privateClause, firstprivateClause, lastprivateClause, linearClause, 993 reductionClause, collapseClause, orderClause, orderedClause, 994 nowaitClause, scheduleClause}; 995 SmallVector<int> segments{numIVs, numIVs, numIVs}; 996 if (failed(parseClauses(parser, result, clauses, segments))) 997 return failure(); 998 999 result.addAttribute("operand_segment_sizes", 1000 parser.getBuilder().getI32VectorAttr(segments)); 1001 1002 // Now parse the body. 1003 Region *body = result.addRegion(); 1004 SmallVector<Type> ivTypes(numIVs, loopVarType); 1005 SmallVector<OpAsmParser::OperandType> blockArgs(ivs); 1006 if (parser.parseRegion(*body, blockArgs, ivTypes)) 1007 return failure(); 1008 return success(); 1009 } 1010 1011 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) { 1012 auto args = op.getRegion().front().getArguments(); 1013 p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound() 1014 << ") to (" << op.upperBound() << ") "; 1015 if (op.inclusive()) { 1016 p << "inclusive "; 1017 } 1018 p << "step (" << op.step() << ") "; 1019 1020 printDataVars(p, op.private_vars(), "private"); 1021 printDataVars(p, op.firstprivate_vars(), "firstprivate"); 1022 printDataVars(p, op.lastprivate_vars(), "lastprivate"); 1023 1024 if (op.linear_vars().size()) 1025 printLinearClause(p, op.linear_vars(), op.linear_step_vars()); 1026 1027 if (auto sched = op.schedule_val()) 1028 printScheduleClause(p, sched.getValue(), op.schedule_modifier(), 1029 op.schedule_chunk_var()); 1030 1031 if (auto collapse = op.collapse_val()) 1032 p << "collapse(" << collapse << ") "; 1033 1034 if (op.nowait()) 1035 p << "nowait "; 1036 1037 if (auto ordered = op.ordered_val()) 1038 p << "ordered(" << ordered << ") "; 1039 1040 if (auto order = op.order_val()) 1041 p << "order(" << order << ") "; 1042 1043 if (!op.reduction_vars().empty()) 1044 printReductionVarList(p, op.reductions(), op.reduction_vars()); 1045 1046 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 1047 } 1048 1049 //===----------------------------------------------------------------------===// 1050 // ReductionOp 1051 //===----------------------------------------------------------------------===// 1052 1053 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, 1054 Region ®ion) { 1055 if (parser.parseOptionalKeyword("atomic")) 1056 return success(); 1057 return parser.parseRegion(region); 1058 } 1059 1060 static void printAtomicReductionRegion(OpAsmPrinter &printer, 1061 ReductionDeclareOp op, Region ®ion) { 1062 if (region.empty()) 1063 return; 1064 printer << "atomic "; 1065 printer.printRegion(region); 1066 } 1067 1068 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) { 1069 if (op.initializerRegion().empty()) 1070 return op.emitOpError() << "expects non-empty initializer region"; 1071 Block &initializerEntryBlock = op.initializerRegion().front(); 1072 if (initializerEntryBlock.getNumArguments() != 1 || 1073 initializerEntryBlock.getArgument(0).getType() != op.type()) { 1074 return op.emitOpError() << "expects initializer region with one argument " 1075 "of the reduction type"; 1076 } 1077 1078 for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) { 1079 if (yieldOp.results().size() != 1 || 1080 yieldOp.results().getTypes()[0] != op.type()) 1081 return op.emitOpError() << "expects initializer region to yield a value " 1082 "of the reduction type"; 1083 } 1084 1085 if (op.reductionRegion().empty()) 1086 return op.emitOpError() << "expects non-empty reduction region"; 1087 Block &reductionEntryBlock = op.reductionRegion().front(); 1088 if (reductionEntryBlock.getNumArguments() != 2 || 1089 reductionEntryBlock.getArgumentTypes()[0] != 1090 reductionEntryBlock.getArgumentTypes()[1] || 1091 reductionEntryBlock.getArgumentTypes()[0] != op.type()) 1092 return op.emitOpError() << "expects reduction region with two arguments of " 1093 "the reduction type"; 1094 for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) { 1095 if (yieldOp.results().size() != 1 || 1096 yieldOp.results().getTypes()[0] != op.type()) 1097 return op.emitOpError() << "expects reduction region to yield a value " 1098 "of the reduction type"; 1099 } 1100 1101 if (op.atomicReductionRegion().empty()) 1102 return success(); 1103 1104 Block &atomicReductionEntryBlock = op.atomicReductionRegion().front(); 1105 if (atomicReductionEntryBlock.getNumArguments() != 2 || 1106 atomicReductionEntryBlock.getArgumentTypes()[0] != 1107 atomicReductionEntryBlock.getArgumentTypes()[1]) 1108 return op.emitOpError() << "expects atomic reduction region with two " 1109 "arguments of the same type"; 1110 auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0] 1111 .dyn_cast<PointerLikeType>(); 1112 if (!ptrType || ptrType.getElementType() != op.type()) 1113 return op.emitOpError() << "expects atomic reduction region arguments to " 1114 "be accumulators containing the reduction type"; 1115 return success(); 1116 } 1117 1118 static LogicalResult verifyReductionOp(ReductionOp op) { 1119 // TODO: generalize this to an op interface when there is more than one op 1120 // that supports reductions. 1121 auto container = op->getParentOfType<WsLoopOp>(); 1122 for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) 1123 if (container.reduction_vars()[i] == op.accumulator()) 1124 return success(); 1125 1126 return op.emitOpError() << "the accumulator is not used by the parent"; 1127 } 1128 1129 //===----------------------------------------------------------------------===// 1130 // WsLoopOp 1131 //===----------------------------------------------------------------------===// 1132 1133 void WsLoopOp::build(OpBuilder &builder, OperationState &state, 1134 ValueRange lowerBound, ValueRange upperBound, 1135 ValueRange step, ArrayRef<NamedAttribute> attributes) { 1136 build(builder, state, TypeRange(), lowerBound, upperBound, step, 1137 /*private_vars=*/ValueRange(), 1138 /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(), 1139 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), 1140 /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr, 1141 /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr, 1142 /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr, 1143 /*inclusive=*/nullptr, /*buildBody=*/false); 1144 state.addAttributes(attributes); 1145 } 1146 1147 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes, 1148 ValueRange operands, ArrayRef<NamedAttribute> attributes) { 1149 state.addOperands(operands); 1150 state.addAttributes(attributes); 1151 (void)state.addRegion(); 1152 assert(resultTypes.empty() && "mismatched number of return types"); 1153 state.addTypes(resultTypes); 1154 } 1155 1156 void WsLoopOp::build(OpBuilder &builder, OperationState &result, 1157 TypeRange typeRange, ValueRange lowerBounds, 1158 ValueRange upperBounds, ValueRange steps, 1159 ValueRange privateVars, ValueRange firstprivateVars, 1160 ValueRange lastprivateVars, ValueRange linearVars, 1161 ValueRange linearStepVars, ValueRange reductionVars, 1162 StringAttr scheduleVal, Value scheduleChunkVar, 1163 IntegerAttr collapseVal, UnitAttr nowait, 1164 IntegerAttr orderedVal, StringAttr orderVal, 1165 UnitAttr inclusive, bool buildBody) { 1166 result.addOperands(lowerBounds); 1167 result.addOperands(upperBounds); 1168 result.addOperands(steps); 1169 result.addOperands(privateVars); 1170 result.addOperands(firstprivateVars); 1171 result.addOperands(linearVars); 1172 result.addOperands(linearStepVars); 1173 if (scheduleChunkVar) 1174 result.addOperands(scheduleChunkVar); 1175 1176 if (scheduleVal) 1177 result.addAttribute("schedule_val", scheduleVal); 1178 if (collapseVal) 1179 result.addAttribute("collapse_val", collapseVal); 1180 if (nowait) 1181 result.addAttribute("nowait", nowait); 1182 if (orderedVal) 1183 result.addAttribute("ordered_val", orderedVal); 1184 if (orderVal) 1185 result.addAttribute("order", orderVal); 1186 if (inclusive) 1187 result.addAttribute("inclusive", inclusive); 1188 result.addAttribute( 1189 WsLoopOp::getOperandSegmentSizeAttr(), 1190 builder.getI32VectorAttr( 1191 {static_cast<int32_t>(lowerBounds.size()), 1192 static_cast<int32_t>(upperBounds.size()), 1193 static_cast<int32_t>(steps.size()), 1194 static_cast<int32_t>(privateVars.size()), 1195 static_cast<int32_t>(firstprivateVars.size()), 1196 static_cast<int32_t>(lastprivateVars.size()), 1197 static_cast<int32_t>(linearVars.size()), 1198 static_cast<int32_t>(linearStepVars.size()), 1199 static_cast<int32_t>(reductionVars.size()), 1200 static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)})); 1201 1202 Region *bodyRegion = result.addRegion(); 1203 if (buildBody) { 1204 OpBuilder::InsertionGuard guard(builder); 1205 unsigned numIVs = steps.size(); 1206 SmallVector<Type, 8> argTypes(numIVs, steps.getType().front()); 1207 builder.createBlock(bodyRegion, {}, argTypes); 1208 } 1209 } 1210 1211 static LogicalResult verifyWsLoopOp(WsLoopOp op) { 1212 return verifyReductionVarList(op, op.reductions(), op.reduction_vars()); 1213 } 1214 1215 //===----------------------------------------------------------------------===// 1216 // Verifier for critical construct (2.17.1) 1217 //===----------------------------------------------------------------------===// 1218 1219 static LogicalResult verifyCriticalDeclareOp(CriticalDeclareOp op) { 1220 return verifySynchronizationHint(op, op.hint()); 1221 } 1222 1223 static LogicalResult verifyCriticalOp(CriticalOp op) { 1224 1225 if (op.nameAttr()) { 1226 auto symbolRef = op.nameAttr().cast<SymbolRefAttr>(); 1227 auto decl = 1228 SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef); 1229 if (!decl) { 1230 return op.emitOpError() << "expected symbol reference " << symbolRef 1231 << " to point to a critical declaration"; 1232 } 1233 } 1234 1235 return success(); 1236 } 1237 1238 //===----------------------------------------------------------------------===// 1239 // Verifier for ordered construct 1240 //===----------------------------------------------------------------------===// 1241 1242 static LogicalResult verifyOrderedOp(OrderedOp op) { 1243 auto container = op->getParentOfType<WsLoopOp>(); 1244 if (!container || !container.ordered_valAttr() || 1245 container.ordered_valAttr().getInt() == 0) 1246 return op.emitOpError() << "ordered depend directive must be closely " 1247 << "nested inside a worksharing-loop with ordered " 1248 << "clause with parameter present"; 1249 1250 if (container.ordered_valAttr().getInt() != 1251 (int64_t)op.num_loops_val().getValue()) 1252 return op.emitOpError() << "number of variables in depend clause does not " 1253 << "match number of iteration variables in the " 1254 << "doacross loop"; 1255 1256 return success(); 1257 } 1258 1259 static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) { 1260 // TODO: The code generation for ordered simd directive is not supported yet. 1261 if (op.simd()) 1262 return failure(); 1263 1264 if (auto container = op->getParentOfType<WsLoopOp>()) { 1265 if (!container.ordered_valAttr() || 1266 container.ordered_valAttr().getInt() != 0) 1267 return op.emitOpError() << "ordered region must be closely nested inside " 1268 << "a worksharing-loop region with an ordered " 1269 << "clause without parameter present"; 1270 } 1271 1272 return success(); 1273 } 1274 1275 //===----------------------------------------------------------------------===// 1276 // AtomicReadOp 1277 //===----------------------------------------------------------------------===// 1278 1279 /// Parser for AtomicReadOp 1280 /// 1281 /// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type 1282 /// address ::= operand `:` type 1283 static ParseResult parseAtomicReadOp(OpAsmParser &parser, 1284 OperationState &result) { 1285 OpAsmParser::OperandType address; 1286 Type addressType; 1287 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1288 SmallVector<int> segments; 1289 1290 if (parser.parseOperand(address) || 1291 parseClauses(parser, result, clauses, segments) || 1292 parser.parseColonType(addressType) || 1293 parser.resolveOperand(address, addressType, result.operands)) 1294 return failure(); 1295 1296 SmallVector<Type> resultType; 1297 if (parser.parseArrowTypeList(resultType)) 1298 return failure(); 1299 result.addTypes(resultType); 1300 return success(); 1301 } 1302 1303 /// Printer for AtomicReadOp 1304 static void printAtomicReadOp(OpAsmPrinter &p, AtomicReadOp op) { 1305 p << " " << op.address() << " "; 1306 if (op.memory_order()) 1307 p << "memory_order(" << op.memory_order().getValue() << ") "; 1308 if (op.hintAttr()) 1309 printSynchronizationHint(p << " ", op, op.hintAttr()); 1310 p << ": " << op.address().getType() << " -> " << op.getType(); 1311 return; 1312 } 1313 1314 /// Verifier for AtomicReadOp 1315 static LogicalResult verifyAtomicReadOp(AtomicReadOp op) { 1316 if (op.memory_order()) { 1317 StringRef memOrder = op.memory_order().getValue(); 1318 if (memOrder.equals("acq_rel") || memOrder.equals("release")) 1319 return op.emitError( 1320 "memory-order must not be acq_rel or release for atomic reads"); 1321 } 1322 return verifySynchronizationHint(op, op.hint()); 1323 } 1324 1325 //===----------------------------------------------------------------------===// 1326 // AtomicWriteOp 1327 //===----------------------------------------------------------------------===// 1328 1329 /// Parser for AtomicWriteOp 1330 /// 1331 /// operation ::= `omp.atomic.write` atomic-clause-list operands 1332 /// operands ::= address `,` value 1333 /// address ::= operand `:` type 1334 /// value ::= operand `:` type 1335 static ParseResult parseAtomicWriteOp(OpAsmParser &parser, 1336 OperationState &result) { 1337 OpAsmParser::OperandType address, value; 1338 Type addrType, valueType; 1339 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1340 SmallVector<int> segments; 1341 1342 if (parser.parseOperand(address) || parser.parseComma() || 1343 parser.parseOperand(value) || 1344 parseClauses(parser, result, clauses, segments) || 1345 parser.parseColonType(addrType) || parser.parseComma() || 1346 parser.parseType(valueType) || 1347 parser.resolveOperand(address, addrType, result.operands) || 1348 parser.resolveOperand(value, valueType, result.operands)) 1349 return failure(); 1350 return success(); 1351 } 1352 1353 /// Printer for AtomicWriteOp 1354 static void printAtomicWriteOp(OpAsmPrinter &p, AtomicWriteOp op) { 1355 p << " " << op.address() << ", " << op.value() << " "; 1356 if (op.memory_order()) 1357 p << "memory_order(" << op.memory_order() << ") "; 1358 if (op.hintAttr()) 1359 printSynchronizationHint(p, op, op.hintAttr()); 1360 p << ": " << op.address().getType() << ", " << op.value().getType(); 1361 return; 1362 } 1363 1364 /// Verifier for AtomicWriteOp 1365 static LogicalResult verifyAtomicWriteOp(AtomicWriteOp op) { 1366 if (op.memory_order()) { 1367 StringRef memoryOrder = op.memory_order().getValue(); 1368 if (memoryOrder.equals("acq_rel") || memoryOrder.equals("acquire")) 1369 return op.emitError( 1370 "memory-order must not be acq_rel or acquire for atomic writes"); 1371 } 1372 return verifySynchronizationHint(op, op.hint()); 1373 } 1374 1375 #define GET_OP_CLASSES 1376 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 1377