1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the OpenMP dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/IR/Attributes.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/IR/OperationSupport.h" 19 20 #include "llvm/ADT/BitVector.h" 21 #include "llvm/ADT/SmallString.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/ADT/StringRef.h" 24 #include "llvm/ADT/StringSwitch.h" 25 #include <cstddef> 26 27 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc" 28 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" 29 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc" 30 31 using namespace mlir; 32 using namespace mlir::omp; 33 34 namespace { 35 /// Model for pointer-like types that already provide a `getElementType` method. 36 template <typename T> 37 struct PointerLikeModel 38 : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> { 39 Type getElementType(Type pointer) const { 40 return pointer.cast<T>().getElementType(); 41 } 42 }; 43 } // end namespace 44 45 void OpenMPDialect::initialize() { 46 addOperations< 47 #define GET_OP_LIST 48 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 49 >(); 50 51 LLVM::LLVMPointerType::attachInterface< 52 PointerLikeModel<LLVM::LLVMPointerType>>(*getContext()); 53 MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext()); 54 } 55 56 //===----------------------------------------------------------------------===// 57 // ParallelOp 58 //===----------------------------------------------------------------------===// 59 60 void ParallelOp::build(OpBuilder &builder, OperationState &state, 61 ArrayRef<NamedAttribute> attributes) { 62 ParallelOp::build( 63 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, 64 /*default_val=*/nullptr, /*private_vars=*/ValueRange(), 65 /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(), 66 /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(), 67 /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr); 68 state.addAttributes(attributes); 69 } 70 71 //===----------------------------------------------------------------------===// 72 // Parser and printer for Operand and type list 73 //===----------------------------------------------------------------------===// 74 75 /// Parse a list of operands with types. 76 /// 77 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)` 78 /// ssa-id-and-type-list ::= ssa-id-and-type | 79 /// ssa-id-and-type `,` ssa-id-and-type-list 80 /// ssa-id-and-type ::= ssa-id `:` type 81 static ParseResult 82 parseOperandAndTypeList(OpAsmParser &parser, 83 SmallVectorImpl<OpAsmParser::OperandType> &operands, 84 SmallVectorImpl<Type> &types) { 85 return parser.parseCommaSeparatedList( 86 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 87 OpAsmParser::OperandType operand; 88 Type type; 89 if (parser.parseOperand(operand) || parser.parseColonType(type)) 90 return failure(); 91 operands.push_back(operand); 92 types.push_back(type); 93 return success(); 94 }); 95 } 96 97 /// Print an operand and type list with parentheses 98 static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) { 99 p << "("; 100 llvm::interleaveComma( 101 operands, p, [&](const Value &v) { p << v << " : " << v.getType(); }); 102 p << ") "; 103 } 104 105 /// Print data variables corresponding to a data-sharing clause `name` 106 static void printDataVars(OpAsmPrinter &p, OperandRange operands, 107 StringRef name) { 108 if (operands.size()) { 109 p << name; 110 printOperandAndTypeList(p, operands); 111 } 112 } 113 114 //===----------------------------------------------------------------------===// 115 // Parser and printer for Allocate Clause 116 //===----------------------------------------------------------------------===// 117 118 /// Parse an allocate clause with allocators and a list of operands with types. 119 /// 120 /// allocate ::= `allocate` `(` allocate-operand-list `)` 121 /// allocate-operand-list :: = allocate-operand | 122 /// allocator-operand `,` allocate-operand-list 123 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 124 /// ssa-id-and-type ::= ssa-id `:` type 125 static ParseResult parseAllocateAndAllocator( 126 OpAsmParser &parser, 127 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate, 128 SmallVectorImpl<Type> &typesAllocate, 129 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator, 130 SmallVectorImpl<Type> &typesAllocator) { 131 132 return parser.parseCommaSeparatedList( 133 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 134 OpAsmParser::OperandType operand; 135 Type type; 136 if (parser.parseOperand(operand) || parser.parseColonType(type)) 137 return failure(); 138 operandsAllocator.push_back(operand); 139 typesAllocator.push_back(type); 140 if (parser.parseArrow()) 141 return failure(); 142 if (parser.parseOperand(operand) || parser.parseColonType(type)) 143 return failure(); 144 145 operandsAllocate.push_back(operand); 146 typesAllocate.push_back(type); 147 return success(); 148 }); 149 } 150 151 /// Print allocate clause 152 static void printAllocateAndAllocator(OpAsmPrinter &p, 153 OperandRange varsAllocate, 154 OperandRange varsAllocator) { 155 p << "allocate("; 156 for (unsigned i = 0; i < varsAllocate.size(); ++i) { 157 std::string separator = i == varsAllocate.size() - 1 ? ") " : ", "; 158 p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> "; 159 p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator; 160 } 161 } 162 163 static LogicalResult verifyParallelOp(ParallelOp op) { 164 if (op.allocate_vars().size() != op.allocators_vars().size()) 165 return op.emitError( 166 "expected equal sizes for allocate and allocator variables"); 167 return success(); 168 } 169 170 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) { 171 p << " "; 172 if (auto ifCond = op.if_expr_var()) 173 p << "if(" << ifCond << " : " << ifCond.getType() << ") "; 174 175 if (auto threads = op.num_threads_var()) 176 p << "num_threads(" << threads << " : " << threads.getType() << ") "; 177 178 printDataVars(p, op.private_vars(), "private"); 179 printDataVars(p, op.firstprivate_vars(), "firstprivate"); 180 printDataVars(p, op.shared_vars(), "shared"); 181 printDataVars(p, op.copyin_vars(), "copyin"); 182 183 if (!op.allocate_vars().empty()) 184 printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars()); 185 186 if (auto def = op.default_val()) 187 p << "default(" << def->drop_front(3) << ") "; 188 189 if (auto bind = op.proc_bind_val()) 190 p << "proc_bind(" << bind << ") "; 191 192 p.printRegion(op.getRegion()); 193 } 194 195 //===----------------------------------------------------------------------===// 196 // Parser and printer for Linear Clause 197 //===----------------------------------------------------------------------===// 198 199 /// linear ::= `linear` `(` linear-list `)` 200 /// linear-list := linear-val | linear-val linear-list 201 /// linear-val := ssa-id-and-type `=` ssa-id-and-type 202 static ParseResult 203 parseLinearClause(OpAsmParser &parser, 204 SmallVectorImpl<OpAsmParser::OperandType> &vars, 205 SmallVectorImpl<Type> &types, 206 SmallVectorImpl<OpAsmParser::OperandType> &stepVars) { 207 if (parser.parseLParen()) 208 return failure(); 209 210 do { 211 OpAsmParser::OperandType var; 212 Type type; 213 OpAsmParser::OperandType stepVar; 214 if (parser.parseOperand(var) || parser.parseEqual() || 215 parser.parseOperand(stepVar) || parser.parseColonType(type)) 216 return failure(); 217 218 vars.push_back(var); 219 types.push_back(type); 220 stepVars.push_back(stepVar); 221 } while (succeeded(parser.parseOptionalComma())); 222 223 if (parser.parseRParen()) 224 return failure(); 225 226 return success(); 227 } 228 229 /// Print Linear Clause 230 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars, 231 OperandRange linearStepVars) { 232 size_t linearVarsSize = linearVars.size(); 233 p << "linear("; 234 for (unsigned i = 0; i < linearVarsSize; ++i) { 235 std::string separator = i == linearVarsSize - 1 ? ") " : ", "; 236 p << linearVars[i]; 237 if (linearStepVars.size() > i) 238 p << " = " << linearStepVars[i]; 239 p << " : " << linearVars[i].getType() << separator; 240 } 241 } 242 243 //===----------------------------------------------------------------------===// 244 // Parser and printer for Schedule Clause 245 //===----------------------------------------------------------------------===// 246 247 static ParseResult 248 verifyScheduleModifiers(OpAsmParser &parser, 249 SmallVectorImpl<SmallString<12>> &modifiers) { 250 if (modifiers.size() > 2) 251 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)"; 252 for (auto mod : modifiers) { 253 // Translate the string. If it has no value, then it was not a valid 254 // modifier! 255 auto symbol = symbolizeScheduleModifier(mod); 256 if (!symbol.hasValue()) 257 return parser.emitError(parser.getNameLoc()) 258 << " unknown modifier type: " << mod; 259 } 260 261 // If we have one modifier that is "simd", then stick a "none" modiifer in 262 // index 0. 263 if (modifiers.size() == 1) { 264 if (symbolizeScheduleModifier(modifiers[0]) == 265 mlir::omp::ScheduleModifier::simd) { 266 modifiers.push_back(modifiers[0]); 267 modifiers[0] = 268 stringifyScheduleModifier(mlir::omp::ScheduleModifier::none); 269 } 270 } else if (modifiers.size() == 2) { 271 // If there are two modifier: 272 // First modifier should not be simd, second one should be simd 273 if (symbolizeScheduleModifier(modifiers[0]) == 274 mlir::omp::ScheduleModifier::simd || 275 symbolizeScheduleModifier(modifiers[1]) != 276 mlir::omp::ScheduleModifier::simd) 277 return parser.emitError(parser.getNameLoc()) 278 << " incorrect modifier order"; 279 } 280 return success(); 281 } 282 283 /// schedule ::= `schedule` `(` sched-list `)` 284 /// sched-list ::= sched-val | sched-val sched-list | 285 /// sched-val `,` sched-modifier 286 /// sched-val ::= sched-with-chunk | sched-wo-chunk 287 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 288 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided` 289 /// sched-wo-chunk ::= `auto` | `runtime` 290 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val 291 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none` 292 static ParseResult 293 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, 294 SmallVectorImpl<SmallString<12>> &modifiers, 295 Optional<OpAsmParser::OperandType> &chunkSize) { 296 if (parser.parseLParen()) 297 return failure(); 298 299 StringRef keyword; 300 if (parser.parseKeyword(&keyword)) 301 return failure(); 302 303 schedule = keyword; 304 if (keyword == "static" || keyword == "dynamic" || keyword == "guided") { 305 if (succeeded(parser.parseOptionalEqual())) { 306 chunkSize = OpAsmParser::OperandType{}; 307 if (parser.parseOperand(*chunkSize)) 308 return failure(); 309 } else { 310 chunkSize = llvm::NoneType::None; 311 } 312 } else if (keyword == "auto" || keyword == "runtime") { 313 chunkSize = llvm::NoneType::None; 314 } else { 315 return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; 316 } 317 318 // If there is a comma, we have one or more modifiers.. 319 while (succeeded(parser.parseOptionalComma())) { 320 StringRef mod; 321 if (parser.parseKeyword(&mod)) 322 return failure(); 323 modifiers.push_back(mod); 324 } 325 326 if (parser.parseRParen()) 327 return failure(); 328 329 if (verifyScheduleModifiers(parser, modifiers)) 330 return failure(); 331 332 return success(); 333 } 334 335 /// Print schedule clause 336 static void printScheduleClause(OpAsmPrinter &p, StringRef &sched, 337 llvm::Optional<StringRef> modifier, bool simd, 338 Value scheduleChunkVar) { 339 std::string schedLower = sched.lower(); 340 p << "schedule(" << schedLower; 341 if (scheduleChunkVar) 342 p << " = " << scheduleChunkVar; 343 if (modifier && modifier.hasValue()) 344 p << ", " << modifier; 345 if (simd) 346 p << ", simd"; 347 p << ") "; 348 } 349 350 //===----------------------------------------------------------------------===// 351 // Parser, printer and verifier for ReductionVarList 352 //===----------------------------------------------------------------------===// 353 354 /// reduction ::= `reduction` `(` reduction-entry-list `)` 355 /// reduction-entry-list ::= reduction-entry 356 /// | reduction-entry-list `,` reduction-entry 357 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type 358 static ParseResult 359 parseReductionVarList(OpAsmParser &parser, 360 SmallVectorImpl<SymbolRefAttr> &symbols, 361 SmallVectorImpl<OpAsmParser::OperandType> &operands, 362 SmallVectorImpl<Type> &types) { 363 if (failed(parser.parseLParen())) 364 return failure(); 365 366 do { 367 if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() || 368 parser.parseOperand(operands.emplace_back()) || 369 parser.parseColonType(types.emplace_back())) 370 return failure(); 371 } while (succeeded(parser.parseOptionalComma())); 372 return parser.parseRParen(); 373 } 374 375 /// Print Reduction clause 376 static void printReductionVarList(OpAsmPrinter &p, 377 Optional<ArrayAttr> reductions, 378 OperandRange reduction_vars) { 379 p << "reduction("; 380 for (unsigned i = 0, e = reductions->size(); i < e; ++i) { 381 if (i != 0) 382 p << ", "; 383 p << (*reductions)[i] << " -> " << reduction_vars[i] << " : " 384 << reduction_vars[i].getType(); 385 } 386 p << ") "; 387 } 388 389 /// Verifies Reduction Clause 390 static LogicalResult verifyReductionVarList(Operation *op, 391 Optional<ArrayAttr> reductions, 392 OperandRange reduction_vars) { 393 if (reduction_vars.size() != 0) { 394 if (!reductions || reductions->size() != reduction_vars.size()) 395 return op->emitOpError() 396 << "expected as many reduction symbol references " 397 "as reduction variables"; 398 } else { 399 if (reductions) 400 return op->emitOpError() << "unexpected reduction symbol references"; 401 return success(); 402 } 403 404 DenseSet<Value> accumulators; 405 for (auto args : llvm::zip(reduction_vars, *reductions)) { 406 Value accum = std::get<0>(args); 407 408 if (!accumulators.insert(accum).second) 409 return op->emitOpError() << "accumulator variable used more than once"; 410 411 Type varType = accum.getType().cast<PointerLikeType>(); 412 auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>(); 413 auto decl = 414 SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef); 415 if (!decl) 416 return op->emitOpError() << "expected symbol reference " << symbolRef 417 << " to point to a reduction declaration"; 418 419 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) 420 return op->emitOpError() 421 << "expected accumulator (" << varType 422 << ") to be the same type as reduction declaration (" 423 << decl.getAccumulatorType() << ")"; 424 } 425 426 return success(); 427 } 428 429 //===----------------------------------------------------------------------===// 430 // Parser, printer and verifier for Synchronization Hint (2.17.12) 431 //===----------------------------------------------------------------------===// 432 433 /// Parses a Synchronization Hint clause. The value of hint is an integer 434 /// which is a combination of different hints from `omp_sync_hint_t`. 435 /// 436 /// hint-clause = `hint` `(` hint-value `)` 437 static ParseResult parseSynchronizationHint(OpAsmParser &parser, 438 IntegerAttr &hintAttr, 439 bool parseKeyword = true) { 440 if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) { 441 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0); 442 return success(); 443 } 444 445 if (failed(parser.parseLParen())) 446 return failure(); 447 StringRef hintKeyword; 448 int64_t hint = 0; 449 do { 450 if (failed(parser.parseKeyword(&hintKeyword))) 451 return failure(); 452 if (hintKeyword == "uncontended") 453 hint |= 1; 454 else if (hintKeyword == "contended") 455 hint |= 2; 456 else if (hintKeyword == "nonspeculative") 457 hint |= 4; 458 else if (hintKeyword == "speculative") 459 hint |= 8; 460 else 461 return parser.emitError(parser.getCurrentLocation()) 462 << hintKeyword << " is not a valid hint"; 463 } while (succeeded(parser.parseOptionalComma())); 464 if (failed(parser.parseRParen())) 465 return failure(); 466 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint); 467 return success(); 468 } 469 470 /// Prints a Synchronization Hint clause 471 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, 472 IntegerAttr hintAttr) { 473 int64_t hint = hintAttr.getInt(); 474 475 if (hint == 0) 476 return; 477 478 // Helper function to get n-th bit from the right end of `value` 479 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 480 481 bool uncontended = bitn(hint, 0); 482 bool contended = bitn(hint, 1); 483 bool nonspeculative = bitn(hint, 2); 484 bool speculative = bitn(hint, 3); 485 486 SmallVector<StringRef> hints; 487 if (uncontended) 488 hints.push_back("uncontended"); 489 if (contended) 490 hints.push_back("contended"); 491 if (nonspeculative) 492 hints.push_back("nonspeculative"); 493 if (speculative) 494 hints.push_back("speculative"); 495 496 p << "hint("; 497 llvm::interleaveComma(hints, p); 498 p << ") "; 499 } 500 501 /// Verifies a synchronization hint clause 502 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) { 503 504 // Helper function to get n-th bit from the right end of `value` 505 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 506 507 bool uncontended = bitn(hint, 0); 508 bool contended = bitn(hint, 1); 509 bool nonspeculative = bitn(hint, 2); 510 bool speculative = bitn(hint, 3); 511 512 if (uncontended && contended) 513 return op->emitOpError() << "the hints omp_sync_hint_uncontended and " 514 "omp_sync_hint_contended cannot be combined"; 515 if (nonspeculative && speculative) 516 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and " 517 "omp_sync_hint_speculative cannot be combined."; 518 return success(); 519 } 520 521 enum ClauseType { 522 ifClause, 523 numThreadsClause, 524 privateClause, 525 firstprivateClause, 526 lastprivateClause, 527 sharedClause, 528 copyinClause, 529 allocateClause, 530 defaultClause, 531 procBindClause, 532 reductionClause, 533 nowaitClause, 534 linearClause, 535 scheduleClause, 536 collapseClause, 537 orderClause, 538 orderedClause, 539 memoryOrderClause, 540 hintClause, 541 COUNT 542 }; 543 544 //===----------------------------------------------------------------------===// 545 // Parser for Clause List 546 //===----------------------------------------------------------------------===// 547 548 /// Parse a list of clauses. The clauses can appear in any order, but their 549 /// operand segment indices are in the same order that they are passed in the 550 /// `clauses` list. The operand segments are added over the prevSegments 551 552 /// clause-list ::= clause clause-list | empty 553 /// clause ::= if | num-threads | private | firstprivate | lastprivate | 554 /// shared | copyin | allocate | default | proc-bind | reduction | 555 /// nowait | linear | schedule | collapse | order | ordered | 556 /// inclusive 557 /// if ::= `if` `(` ssa-id-and-type `)` 558 /// num-threads ::= `num_threads` `(` ssa-id-and-type `)` 559 /// private ::= `private` operand-and-type-list 560 /// firstprivate ::= `firstprivate` operand-and-type-list 561 /// lastprivate ::= `lastprivate` operand-and-type-list 562 /// shared ::= `shared` operand-and-type-list 563 /// copyin ::= `copyin` operand-and-type-list 564 /// allocate ::= `allocate` `(` allocate-operand-list `)` 565 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) 566 /// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` 567 /// reduction ::= `reduction` `(` reduction-entry-list `)` 568 /// nowait ::= `nowait` 569 /// linear ::= `linear` `(` linear-list `)` 570 /// schedule ::= `schedule` `(` sched-list `)` 571 /// collapse ::= `collapse` `(` ssa-id-and-type `)` 572 /// order ::= `order` `(` `concurrent` `)` 573 /// ordered ::= `ordered` `(` ssa-id-and-type `)` 574 /// inclusive ::= `inclusive` 575 /// 576 /// Note that each clause can only appear once in the clase-list. 577 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result, 578 SmallVectorImpl<ClauseType> &clauses, 579 SmallVectorImpl<int> &segments) { 580 581 // Check done[clause] to see if it has been parsed already 582 llvm::BitVector done(ClauseType::COUNT, false); 583 584 // See pos[clause] to get position of clause in operand segments 585 SmallVector<int> pos(ClauseType::COUNT, -1); 586 587 // Stores the last parsed clause keyword 588 StringRef clauseKeyword; 589 StringRef opName = result.name.getStringRef(); 590 591 // Containers for storing operands, types and attributes for various clauses 592 std::pair<OpAsmParser::OperandType, Type> ifCond; 593 std::pair<OpAsmParser::OperandType, Type> numThreads; 594 595 SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates, 596 shareds, copyins; 597 SmallVector<Type> privateTypes, firstprivateTypes, lastprivateTypes, 598 sharedTypes, copyinTypes; 599 600 SmallVector<OpAsmParser::OperandType> allocates, allocators; 601 SmallVector<Type> allocateTypes, allocatorTypes; 602 603 SmallVector<SymbolRefAttr> reductionSymbols; 604 SmallVector<OpAsmParser::OperandType> reductionVars; 605 SmallVector<Type> reductionVarTypes; 606 607 SmallVector<OpAsmParser::OperandType> linears; 608 SmallVector<Type> linearTypes; 609 SmallVector<OpAsmParser::OperandType> linearSteps; 610 611 SmallString<8> schedule; 612 SmallVector<SmallString<12>> modifiers; 613 Optional<OpAsmParser::OperandType> scheduleChunkSize; 614 615 // Compute the position of clauses in operand segments 616 int currPos = 0; 617 for (ClauseType clause : clauses) { 618 619 // Skip the following clauses - they do not take any position in operand 620 // segments 621 if (clause == defaultClause || clause == procBindClause || 622 clause == nowaitClause || clause == collapseClause || 623 clause == orderClause || clause == orderedClause) 624 continue; 625 626 pos[clause] = currPos++; 627 628 // For the following clauses, two positions are reserved in the operand 629 // segments 630 if (clause == allocateClause || clause == linearClause) 631 currPos++; 632 } 633 634 SmallVector<int> clauseSegments(currPos); 635 636 // Helper function to check if a clause is allowed/repeated or not 637 auto checkAllowed = [&](ClauseType clause, 638 bool allowRepeat = false) -> ParseResult { 639 if (!llvm::is_contained(clauses, clause)) 640 return parser.emitError(parser.getCurrentLocation()) 641 << clauseKeyword << " is not a valid clause for the " << opName 642 << " operation"; 643 if (done[clause] && !allowRepeat) 644 return parser.emitError(parser.getCurrentLocation()) 645 << "at most one " << clauseKeyword << " clause can appear on the " 646 << opName << " operation"; 647 done[clause] = true; 648 return success(); 649 }; 650 651 while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) { 652 if (clauseKeyword == "if") { 653 if (checkAllowed(ifClause) || parser.parseLParen() || 654 parser.parseOperand(ifCond.first) || 655 parser.parseColonType(ifCond.second) || parser.parseRParen()) 656 return failure(); 657 clauseSegments[pos[ifClause]] = 1; 658 } else if (clauseKeyword == "num_threads") { 659 if (checkAllowed(numThreadsClause) || parser.parseLParen() || 660 parser.parseOperand(numThreads.first) || 661 parser.parseColonType(numThreads.second) || parser.parseRParen()) 662 return failure(); 663 clauseSegments[pos[numThreadsClause]] = 1; 664 } else if (clauseKeyword == "private") { 665 if (checkAllowed(privateClause) || 666 parseOperandAndTypeList(parser, privates, privateTypes)) 667 return failure(); 668 clauseSegments[pos[privateClause]] = privates.size(); 669 } else if (clauseKeyword == "firstprivate") { 670 if (checkAllowed(firstprivateClause) || 671 parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) 672 return failure(); 673 clauseSegments[pos[firstprivateClause]] = firstprivates.size(); 674 } else if (clauseKeyword == "lastprivate") { 675 if (checkAllowed(lastprivateClause) || 676 parseOperandAndTypeList(parser, lastprivates, lastprivateTypes)) 677 return failure(); 678 clauseSegments[pos[lastprivateClause]] = lastprivates.size(); 679 } else if (clauseKeyword == "shared") { 680 if (checkAllowed(sharedClause) || 681 parseOperandAndTypeList(parser, shareds, sharedTypes)) 682 return failure(); 683 clauseSegments[pos[sharedClause]] = shareds.size(); 684 } else if (clauseKeyword == "copyin") { 685 if (checkAllowed(copyinClause) || 686 parseOperandAndTypeList(parser, copyins, copyinTypes)) 687 return failure(); 688 clauseSegments[pos[copyinClause]] = copyins.size(); 689 } else if (clauseKeyword == "allocate") { 690 if (checkAllowed(allocateClause) || 691 parseAllocateAndAllocator(parser, allocates, allocateTypes, 692 allocators, allocatorTypes)) 693 return failure(); 694 clauseSegments[pos[allocateClause]] = allocates.size(); 695 clauseSegments[pos[allocateClause] + 1] = allocators.size(); 696 } else if (clauseKeyword == "default") { 697 StringRef defval; 698 if (checkAllowed(defaultClause) || parser.parseLParen() || 699 parser.parseKeyword(&defval) || parser.parseRParen()) 700 return failure(); 701 // The def prefix is required for the attribute as "private" is a keyword 702 // in C++. 703 auto attr = parser.getBuilder().getStringAttr("def" + defval); 704 result.addAttribute("default_val", attr); 705 } else if (clauseKeyword == "proc_bind") { 706 StringRef bind; 707 if (checkAllowed(procBindClause) || parser.parseLParen() || 708 parser.parseKeyword(&bind) || parser.parseRParen()) 709 return failure(); 710 auto attr = parser.getBuilder().getStringAttr(bind); 711 result.addAttribute("proc_bind_val", attr); 712 } else if (clauseKeyword == "reduction") { 713 if (checkAllowed(reductionClause) || 714 parseReductionVarList(parser, reductionSymbols, reductionVars, 715 reductionVarTypes)) 716 return failure(); 717 clauseSegments[pos[reductionClause]] = reductionVars.size(); 718 } else if (clauseKeyword == "nowait") { 719 if (checkAllowed(nowaitClause)) 720 return failure(); 721 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 722 result.addAttribute("nowait", attr); 723 } else if (clauseKeyword == "linear") { 724 if (checkAllowed(linearClause) || 725 parseLinearClause(parser, linears, linearTypes, linearSteps)) 726 return failure(); 727 clauseSegments[pos[linearClause]] = linears.size(); 728 clauseSegments[pos[linearClause] + 1] = linearSteps.size(); 729 } else if (clauseKeyword == "schedule") { 730 if (checkAllowed(scheduleClause) || 731 parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize)) 732 return failure(); 733 if (scheduleChunkSize) { 734 clauseSegments[pos[scheduleClause]] = 1; 735 } 736 } else if (clauseKeyword == "collapse") { 737 auto type = parser.getBuilder().getI64Type(); 738 mlir::IntegerAttr attr; 739 if (checkAllowed(collapseClause) || parser.parseLParen() || 740 parser.parseAttribute(attr, type) || parser.parseRParen()) 741 return failure(); 742 result.addAttribute("collapse_val", attr); 743 } else if (clauseKeyword == "ordered") { 744 mlir::IntegerAttr attr; 745 if (checkAllowed(orderedClause)) 746 return failure(); 747 if (succeeded(parser.parseOptionalLParen())) { 748 auto type = parser.getBuilder().getI64Type(); 749 if (parser.parseAttribute(attr, type) || parser.parseRParen()) 750 return failure(); 751 } else { 752 // Use 0 to represent no ordered parameter was specified 753 attr = parser.getBuilder().getI64IntegerAttr(0); 754 } 755 result.addAttribute("ordered_val", attr); 756 } else if (clauseKeyword == "order") { 757 StringRef order; 758 if (checkAllowed(orderClause) || parser.parseLParen() || 759 parser.parseKeyword(&order) || parser.parseRParen()) 760 return failure(); 761 auto attr = parser.getBuilder().getStringAttr(order); 762 result.addAttribute("order_val", attr); 763 } else if (clauseKeyword == "memory_order") { 764 StringRef memoryOrder; 765 if (checkAllowed(memoryOrderClause) || parser.parseLParen() || 766 parser.parseKeyword(&memoryOrder) || parser.parseRParen()) 767 return failure(); 768 result.addAttribute("memory_order", 769 parser.getBuilder().getStringAttr(memoryOrder)); 770 } else if (clauseKeyword == "hint") { 771 IntegerAttr hint; 772 if (checkAllowed(hintClause) || 773 parseSynchronizationHint(parser, hint, false)) 774 return failure(); 775 result.addAttribute("hint", hint); 776 } else { 777 return parser.emitError(parser.getNameLoc()) 778 << clauseKeyword << " is not a valid clause"; 779 } 780 } 781 782 // Add if parameter. 783 if (done[ifClause] && clauseSegments[pos[ifClause]] && 784 failed( 785 parser.resolveOperand(ifCond.first, ifCond.second, result.operands))) 786 return failure(); 787 788 // Add num_threads parameter. 789 if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] && 790 failed(parser.resolveOperand(numThreads.first, numThreads.second, 791 result.operands))) 792 return failure(); 793 794 // Add private parameters. 795 if (done[privateClause] && clauseSegments[pos[privateClause]] && 796 failed(parser.resolveOperands(privates, privateTypes, 797 privates[0].location, result.operands))) 798 return failure(); 799 800 // Add firstprivate parameters. 801 if (done[firstprivateClause] && clauseSegments[pos[firstprivateClause]] && 802 failed(parser.resolveOperands(firstprivates, firstprivateTypes, 803 firstprivates[0].location, 804 result.operands))) 805 return failure(); 806 807 // Add lastprivate parameters. 808 if (done[lastprivateClause] && clauseSegments[pos[lastprivateClause]] && 809 failed(parser.resolveOperands(lastprivates, lastprivateTypes, 810 lastprivates[0].location, result.operands))) 811 return failure(); 812 813 // Add shared parameters. 814 if (done[sharedClause] && clauseSegments[pos[sharedClause]] && 815 failed(parser.resolveOperands(shareds, sharedTypes, shareds[0].location, 816 result.operands))) 817 return failure(); 818 819 // Add copyin parameters. 820 if (done[copyinClause] && clauseSegments[pos[copyinClause]] && 821 failed(parser.resolveOperands(copyins, copyinTypes, copyins[0].location, 822 result.operands))) 823 return failure(); 824 825 // Add allocate parameters. 826 if (done[allocateClause] && clauseSegments[pos[allocateClause]] && 827 failed(parser.resolveOperands(allocates, allocateTypes, 828 allocates[0].location, result.operands))) 829 return failure(); 830 831 // Add allocator parameters. 832 if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] && 833 failed(parser.resolveOperands(allocators, allocatorTypes, 834 allocators[0].location, result.operands))) 835 return failure(); 836 837 // Add reduction parameters and symbols 838 if (done[reductionClause] && clauseSegments[pos[reductionClause]]) { 839 if (failed(parser.resolveOperands(reductionVars, reductionVarTypes, 840 parser.getNameLoc(), result.operands))) 841 return failure(); 842 843 SmallVector<Attribute> reductions(reductionSymbols.begin(), 844 reductionSymbols.end()); 845 result.addAttribute("reductions", 846 parser.getBuilder().getArrayAttr(reductions)); 847 } 848 849 // Add linear parameters 850 if (done[linearClause] && clauseSegments[pos[linearClause]]) { 851 auto linearStepType = parser.getBuilder().getI32Type(); 852 SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType); 853 if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location, 854 result.operands)) || 855 failed(parser.resolveOperands(linearSteps, linearStepTypes, 856 linearSteps[0].location, 857 result.operands))) 858 return failure(); 859 } 860 861 // Add schedule parameters 862 if (done[scheduleClause] && !schedule.empty()) { 863 schedule[0] = llvm::toUpper(schedule[0]); 864 auto attr = parser.getBuilder().getStringAttr(schedule); 865 result.addAttribute("schedule_val", attr); 866 if (modifiers.size() > 0) { 867 auto mod = parser.getBuilder().getStringAttr(modifiers[0]); 868 result.addAttribute("schedule_modifier", mod); 869 // Only SIMD attribute is allowed here! 870 if (modifiers.size() > 1) { 871 assert(symbolizeScheduleModifier(modifiers[1]) == 872 mlir::omp::ScheduleModifier::simd); 873 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 874 result.addAttribute("simd_modifier", attr); 875 } 876 } 877 if (scheduleChunkSize) { 878 auto chunkSizeType = parser.getBuilder().getI32Type(); 879 parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands); 880 } 881 } 882 883 segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end()); 884 885 return success(); 886 } 887 888 /// Parses a parallel operation. 889 /// 890 /// operation ::= `omp.parallel` clause-list 891 /// clause-list ::= clause | clause clause-list 892 /// clause ::= if | num-threads | private | firstprivate | shared | copyin | 893 /// allocate | default | proc-bind 894 /// 895 static ParseResult parseParallelOp(OpAsmParser &parser, 896 OperationState &result) { 897 SmallVector<ClauseType> clauses = { 898 ifClause, numThreadsClause, privateClause, 899 firstprivateClause, sharedClause, copyinClause, 900 allocateClause, defaultClause, procBindClause}; 901 902 SmallVector<int> segments; 903 904 if (failed(parseClauses(parser, result, clauses, segments))) 905 return failure(); 906 907 result.addAttribute("operand_segment_sizes", 908 parser.getBuilder().getI32VectorAttr(segments)); 909 910 Region *body = result.addRegion(); 911 SmallVector<OpAsmParser::OperandType> regionArgs; 912 SmallVector<Type> regionArgTypes; 913 if (parser.parseRegion(*body, regionArgs, regionArgTypes)) 914 return failure(); 915 return success(); 916 } 917 918 //===----------------------------------------------------------------------===// 919 // Parser, printer and verifier for SectionsOp 920 //===----------------------------------------------------------------------===// 921 922 /// Parses an OpenMP Sections operation 923 /// 924 /// sections ::= `omp.sections` clause-list 925 /// clause-list ::= clause clause-list | empty 926 /// clause ::= private | firstprivate | lastprivate | reduction | allocate | 927 /// nowait 928 static ParseResult parseSectionsOp(OpAsmParser &parser, 929 OperationState &result) { 930 931 SmallVector<ClauseType> clauses = {privateClause, firstprivateClause, 932 lastprivateClause, reductionClause, 933 allocateClause, nowaitClause}; 934 935 SmallVector<int> segments; 936 937 if (failed(parseClauses(parser, result, clauses, segments))) 938 return failure(); 939 940 result.addAttribute("operand_segment_sizes", 941 parser.getBuilder().getI32VectorAttr(segments)); 942 943 // Now parse the body. 944 Region *body = result.addRegion(); 945 if (parser.parseRegion(*body)) 946 return failure(); 947 return success(); 948 } 949 950 static void printSectionsOp(OpAsmPrinter &p, SectionsOp op) { 951 p << " "; 952 printDataVars(p, op.private_vars(), "private"); 953 printDataVars(p, op.firstprivate_vars(), "firstprivate"); 954 printDataVars(p, op.lastprivate_vars(), "lastprivate"); 955 956 if (!op.reduction_vars().empty()) 957 printReductionVarList(p, op.reductions(), op.reduction_vars()); 958 959 if (!op.allocate_vars().empty()) 960 printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars()); 961 962 if (op.nowait()) 963 p << "nowait "; 964 965 p.printRegion(op.region()); 966 } 967 968 static LogicalResult verifySectionsOp(SectionsOp op) { 969 970 // A list item may not appear in more than one clause on the same directive, 971 // except that it may be specified in both firstprivate and lastprivate 972 // clauses. 973 for (auto var : op.private_vars()) { 974 if (llvm::is_contained(op.firstprivate_vars(), var)) 975 return op.emitOpError() 976 << "operand used in both private and firstprivate clauses"; 977 if (llvm::is_contained(op.lastprivate_vars(), var)) 978 return op.emitOpError() 979 << "operand used in both private and lastprivate clauses"; 980 } 981 982 if (op.allocate_vars().size() != op.allocators_vars().size()) 983 return op.emitError( 984 "expected equal sizes for allocate and allocator variables"); 985 986 for (auto &inst : *op.region().begin()) { 987 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) 988 op.emitOpError() 989 << "expected omp.section op or terminator op inside region"; 990 } 991 992 return verifyReductionVarList(op, op.reductions(), op.reduction_vars()); 993 } 994 995 /// Parses an OpenMP Workshare Loop operation 996 /// 997 /// wsloop ::= `omp.wsloop` loop-control clause-list 998 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds 999 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps 1000 /// steps := `step` `(`ssa-id-list`)` 1001 /// clause-list ::= clause clause-list | empty 1002 /// clause ::= private | firstprivate | lastprivate | linear | schedule | 1003 // collapse | nowait | ordered | order | reduction 1004 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) { 1005 1006 // Parse an opening `(` followed by induction variables followed by `)` 1007 SmallVector<OpAsmParser::OperandType> ivs; 1008 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, 1009 OpAsmParser::Delimiter::Paren)) 1010 return failure(); 1011 1012 int numIVs = static_cast<int>(ivs.size()); 1013 Type loopVarType; 1014 if (parser.parseColonType(loopVarType)) 1015 return failure(); 1016 1017 // Parse loop bounds. 1018 SmallVector<OpAsmParser::OperandType> lower; 1019 if (parser.parseEqual() || 1020 parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) || 1021 parser.resolveOperands(lower, loopVarType, result.operands)) 1022 return failure(); 1023 1024 SmallVector<OpAsmParser::OperandType> upper; 1025 if (parser.parseKeyword("to") || 1026 parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) || 1027 parser.resolveOperands(upper, loopVarType, result.operands)) 1028 return failure(); 1029 1030 if (succeeded(parser.parseOptionalKeyword("inclusive"))) { 1031 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 1032 result.addAttribute("inclusive", attr); 1033 } 1034 1035 // Parse step values. 1036 SmallVector<OpAsmParser::OperandType> steps; 1037 if (parser.parseKeyword("step") || 1038 parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) || 1039 parser.resolveOperands(steps, loopVarType, result.operands)) 1040 return failure(); 1041 1042 SmallVector<ClauseType> clauses = { 1043 privateClause, firstprivateClause, lastprivateClause, linearClause, 1044 reductionClause, collapseClause, orderClause, orderedClause, 1045 nowaitClause, scheduleClause}; 1046 SmallVector<int> segments{numIVs, numIVs, numIVs}; 1047 if (failed(parseClauses(parser, result, clauses, segments))) 1048 return failure(); 1049 1050 result.addAttribute("operand_segment_sizes", 1051 parser.getBuilder().getI32VectorAttr(segments)); 1052 1053 // Now parse the body. 1054 Region *body = result.addRegion(); 1055 SmallVector<Type> ivTypes(numIVs, loopVarType); 1056 SmallVector<OpAsmParser::OperandType> blockArgs(ivs); 1057 if (parser.parseRegion(*body, blockArgs, ivTypes)) 1058 return failure(); 1059 return success(); 1060 } 1061 1062 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) { 1063 auto args = op.getRegion().front().getArguments(); 1064 p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound() 1065 << ") to (" << op.upperBound() << ") "; 1066 if (op.inclusive()) { 1067 p << "inclusive "; 1068 } 1069 p << "step (" << op.step() << ") "; 1070 1071 printDataVars(p, op.private_vars(), "private"); 1072 printDataVars(p, op.firstprivate_vars(), "firstprivate"); 1073 printDataVars(p, op.lastprivate_vars(), "lastprivate"); 1074 1075 if (op.linear_vars().size()) 1076 printLinearClause(p, op.linear_vars(), op.linear_step_vars()); 1077 1078 if (auto sched = op.schedule_val()) 1079 printScheduleClause(p, sched.getValue(), op.schedule_modifier(), 1080 op.simd_modifier(), op.schedule_chunk_var()); 1081 1082 if (auto collapse = op.collapse_val()) 1083 p << "collapse(" << collapse << ") "; 1084 1085 if (op.nowait()) 1086 p << "nowait "; 1087 1088 if (auto ordered = op.ordered_val()) 1089 p << "ordered(" << ordered << ") "; 1090 1091 if (auto order = op.order_val()) 1092 p << "order(" << order << ") "; 1093 1094 if (!op.reduction_vars().empty()) 1095 printReductionVarList(p, op.reductions(), op.reduction_vars()); 1096 1097 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 1098 } 1099 1100 //===----------------------------------------------------------------------===// 1101 // ReductionOp 1102 //===----------------------------------------------------------------------===// 1103 1104 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, 1105 Region ®ion) { 1106 if (parser.parseOptionalKeyword("atomic")) 1107 return success(); 1108 return parser.parseRegion(region); 1109 } 1110 1111 static void printAtomicReductionRegion(OpAsmPrinter &printer, 1112 ReductionDeclareOp op, Region ®ion) { 1113 if (region.empty()) 1114 return; 1115 printer << "atomic "; 1116 printer.printRegion(region); 1117 } 1118 1119 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) { 1120 if (op.initializerRegion().empty()) 1121 return op.emitOpError() << "expects non-empty initializer region"; 1122 Block &initializerEntryBlock = op.initializerRegion().front(); 1123 if (initializerEntryBlock.getNumArguments() != 1 || 1124 initializerEntryBlock.getArgument(0).getType() != op.type()) { 1125 return op.emitOpError() << "expects initializer region with one argument " 1126 "of the reduction type"; 1127 } 1128 1129 for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) { 1130 if (yieldOp.results().size() != 1 || 1131 yieldOp.results().getTypes()[0] != op.type()) 1132 return op.emitOpError() << "expects initializer region to yield a value " 1133 "of the reduction type"; 1134 } 1135 1136 if (op.reductionRegion().empty()) 1137 return op.emitOpError() << "expects non-empty reduction region"; 1138 Block &reductionEntryBlock = op.reductionRegion().front(); 1139 if (reductionEntryBlock.getNumArguments() != 2 || 1140 reductionEntryBlock.getArgumentTypes()[0] != 1141 reductionEntryBlock.getArgumentTypes()[1] || 1142 reductionEntryBlock.getArgumentTypes()[0] != op.type()) 1143 return op.emitOpError() << "expects reduction region with two arguments of " 1144 "the reduction type"; 1145 for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) { 1146 if (yieldOp.results().size() != 1 || 1147 yieldOp.results().getTypes()[0] != op.type()) 1148 return op.emitOpError() << "expects reduction region to yield a value " 1149 "of the reduction type"; 1150 } 1151 1152 if (op.atomicReductionRegion().empty()) 1153 return success(); 1154 1155 Block &atomicReductionEntryBlock = op.atomicReductionRegion().front(); 1156 if (atomicReductionEntryBlock.getNumArguments() != 2 || 1157 atomicReductionEntryBlock.getArgumentTypes()[0] != 1158 atomicReductionEntryBlock.getArgumentTypes()[1]) 1159 return op.emitOpError() << "expects atomic reduction region with two " 1160 "arguments of the same type"; 1161 auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0] 1162 .dyn_cast<PointerLikeType>(); 1163 if (!ptrType || ptrType.getElementType() != op.type()) 1164 return op.emitOpError() << "expects atomic reduction region arguments to " 1165 "be accumulators containing the reduction type"; 1166 return success(); 1167 } 1168 1169 static LogicalResult verifyReductionOp(ReductionOp op) { 1170 // TODO: generalize this to an op interface when there is more than one op 1171 // that supports reductions. 1172 auto container = op->getParentOfType<WsLoopOp>(); 1173 for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) 1174 if (container.reduction_vars()[i] == op.accumulator()) 1175 return success(); 1176 1177 return op.emitOpError() << "the accumulator is not used by the parent"; 1178 } 1179 1180 //===----------------------------------------------------------------------===// 1181 // WsLoopOp 1182 //===----------------------------------------------------------------------===// 1183 1184 void WsLoopOp::build(OpBuilder &builder, OperationState &state, 1185 ValueRange lowerBound, ValueRange upperBound, 1186 ValueRange step, ArrayRef<NamedAttribute> attributes) { 1187 build(builder, state, TypeRange(), lowerBound, upperBound, step, 1188 /*private_vars=*/ValueRange(), 1189 /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(), 1190 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), 1191 /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr, 1192 /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr, 1193 /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr, 1194 /*inclusive=*/nullptr, /*buildBody=*/false); 1195 state.addAttributes(attributes); 1196 } 1197 1198 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes, 1199 ValueRange operands, ArrayRef<NamedAttribute> attributes) { 1200 state.addOperands(operands); 1201 state.addAttributes(attributes); 1202 (void)state.addRegion(); 1203 assert(resultTypes.empty() && "mismatched number of return types"); 1204 state.addTypes(resultTypes); 1205 } 1206 1207 void WsLoopOp::build(OpBuilder &builder, OperationState &result, 1208 TypeRange typeRange, ValueRange lowerBounds, 1209 ValueRange upperBounds, ValueRange steps, 1210 ValueRange privateVars, ValueRange firstprivateVars, 1211 ValueRange lastprivateVars, ValueRange linearVars, 1212 ValueRange linearStepVars, ValueRange reductionVars, 1213 StringAttr scheduleVal, Value scheduleChunkVar, 1214 IntegerAttr collapseVal, UnitAttr nowait, 1215 IntegerAttr orderedVal, StringAttr orderVal, 1216 UnitAttr inclusive, bool buildBody) { 1217 result.addOperands(lowerBounds); 1218 result.addOperands(upperBounds); 1219 result.addOperands(steps); 1220 result.addOperands(privateVars); 1221 result.addOperands(firstprivateVars); 1222 result.addOperands(linearVars); 1223 result.addOperands(linearStepVars); 1224 if (scheduleChunkVar) 1225 result.addOperands(scheduleChunkVar); 1226 1227 if (scheduleVal) 1228 result.addAttribute("schedule_val", scheduleVal); 1229 if (collapseVal) 1230 result.addAttribute("collapse_val", collapseVal); 1231 if (nowait) 1232 result.addAttribute("nowait", nowait); 1233 if (orderedVal) 1234 result.addAttribute("ordered_val", orderedVal); 1235 if (orderVal) 1236 result.addAttribute("order", orderVal); 1237 if (inclusive) 1238 result.addAttribute("inclusive", inclusive); 1239 result.addAttribute( 1240 WsLoopOp::getOperandSegmentSizeAttr(), 1241 builder.getI32VectorAttr( 1242 {static_cast<int32_t>(lowerBounds.size()), 1243 static_cast<int32_t>(upperBounds.size()), 1244 static_cast<int32_t>(steps.size()), 1245 static_cast<int32_t>(privateVars.size()), 1246 static_cast<int32_t>(firstprivateVars.size()), 1247 static_cast<int32_t>(lastprivateVars.size()), 1248 static_cast<int32_t>(linearVars.size()), 1249 static_cast<int32_t>(linearStepVars.size()), 1250 static_cast<int32_t>(reductionVars.size()), 1251 static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)})); 1252 1253 Region *bodyRegion = result.addRegion(); 1254 if (buildBody) { 1255 OpBuilder::InsertionGuard guard(builder); 1256 unsigned numIVs = steps.size(); 1257 SmallVector<Type, 8> argTypes(numIVs, steps.getType().front()); 1258 builder.createBlock(bodyRegion, {}, argTypes); 1259 } 1260 } 1261 1262 static LogicalResult verifyWsLoopOp(WsLoopOp op) { 1263 return verifyReductionVarList(op, op.reductions(), op.reduction_vars()); 1264 } 1265 1266 //===----------------------------------------------------------------------===// 1267 // Verifier for critical construct (2.17.1) 1268 //===----------------------------------------------------------------------===// 1269 1270 static LogicalResult verifyCriticalDeclareOp(CriticalDeclareOp op) { 1271 return verifySynchronizationHint(op, op.hint()); 1272 } 1273 1274 static LogicalResult verifyCriticalOp(CriticalOp op) { 1275 1276 if (op.nameAttr()) { 1277 auto symbolRef = op.nameAttr().cast<SymbolRefAttr>(); 1278 auto decl = 1279 SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef); 1280 if (!decl) { 1281 return op.emitOpError() << "expected symbol reference " << symbolRef 1282 << " to point to a critical declaration"; 1283 } 1284 } 1285 1286 return success(); 1287 } 1288 1289 //===----------------------------------------------------------------------===// 1290 // Verifier for ordered construct 1291 //===----------------------------------------------------------------------===// 1292 1293 static LogicalResult verifyOrderedOp(OrderedOp op) { 1294 auto container = op->getParentOfType<WsLoopOp>(); 1295 if (!container || !container.ordered_valAttr() || 1296 container.ordered_valAttr().getInt() == 0) 1297 return op.emitOpError() << "ordered depend directive must be closely " 1298 << "nested inside a worksharing-loop with ordered " 1299 << "clause with parameter present"; 1300 1301 if (container.ordered_valAttr().getInt() != 1302 (int64_t)op.num_loops_val().getValue()) 1303 return op.emitOpError() << "number of variables in depend clause does not " 1304 << "match number of iteration variables in the " 1305 << "doacross loop"; 1306 1307 return success(); 1308 } 1309 1310 static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) { 1311 // TODO: The code generation for ordered simd directive is not supported yet. 1312 if (op.simd()) 1313 return failure(); 1314 1315 if (auto container = op->getParentOfType<WsLoopOp>()) { 1316 if (!container.ordered_valAttr() || 1317 container.ordered_valAttr().getInt() != 0) 1318 return op.emitOpError() << "ordered region must be closely nested inside " 1319 << "a worksharing-loop region with an ordered " 1320 << "clause without parameter present"; 1321 } 1322 1323 return success(); 1324 } 1325 1326 //===----------------------------------------------------------------------===// 1327 // AtomicReadOp 1328 //===----------------------------------------------------------------------===// 1329 1330 /// Parser for AtomicReadOp 1331 /// 1332 /// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type 1333 /// address ::= operand `:` type 1334 static ParseResult parseAtomicReadOp(OpAsmParser &parser, 1335 OperationState &result) { 1336 OpAsmParser::OperandType address; 1337 Type addressType; 1338 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1339 SmallVector<int> segments; 1340 1341 if (parser.parseOperand(address) || 1342 parseClauses(parser, result, clauses, segments) || 1343 parser.parseColonType(addressType) || 1344 parser.resolveOperand(address, addressType, result.operands)) 1345 return failure(); 1346 1347 SmallVector<Type> resultType; 1348 if (parser.parseArrowTypeList(resultType)) 1349 return failure(); 1350 result.addTypes(resultType); 1351 return success(); 1352 } 1353 1354 /// Printer for AtomicReadOp 1355 static void printAtomicReadOp(OpAsmPrinter &p, AtomicReadOp op) { 1356 p << " " << op.address() << " "; 1357 if (op.memory_order()) 1358 p << "memory_order(" << op.memory_order().getValue() << ") "; 1359 if (op.hintAttr()) 1360 printSynchronizationHint(p << " ", op, op.hintAttr()); 1361 p << ": " << op.address().getType() << " -> " << op.getType(); 1362 return; 1363 } 1364 1365 /// Verifier for AtomicReadOp 1366 static LogicalResult verifyAtomicReadOp(AtomicReadOp op) { 1367 if (op.memory_order()) { 1368 StringRef memOrder = op.memory_order().getValue(); 1369 if (memOrder.equals("acq_rel") || memOrder.equals("release")) 1370 return op.emitError( 1371 "memory-order must not be acq_rel or release for atomic reads"); 1372 } 1373 return verifySynchronizationHint(op, op.hint()); 1374 } 1375 1376 //===----------------------------------------------------------------------===// 1377 // AtomicWriteOp 1378 //===----------------------------------------------------------------------===// 1379 1380 /// Parser for AtomicWriteOp 1381 /// 1382 /// operation ::= `omp.atomic.write` atomic-clause-list operands 1383 /// operands ::= address `,` value 1384 /// address ::= operand `:` type 1385 /// value ::= operand `:` type 1386 static ParseResult parseAtomicWriteOp(OpAsmParser &parser, 1387 OperationState &result) { 1388 OpAsmParser::OperandType address, value; 1389 Type addrType, valueType; 1390 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1391 SmallVector<int> segments; 1392 1393 if (parser.parseOperand(address) || parser.parseComma() || 1394 parser.parseOperand(value) || 1395 parseClauses(parser, result, clauses, segments) || 1396 parser.parseColonType(addrType) || parser.parseComma() || 1397 parser.parseType(valueType) || 1398 parser.resolveOperand(address, addrType, result.operands) || 1399 parser.resolveOperand(value, valueType, result.operands)) 1400 return failure(); 1401 return success(); 1402 } 1403 1404 /// Printer for AtomicWriteOp 1405 static void printAtomicWriteOp(OpAsmPrinter &p, AtomicWriteOp op) { 1406 p << " " << op.address() << ", " << op.value() << " "; 1407 if (op.memory_order()) 1408 p << "memory_order(" << op.memory_order() << ") "; 1409 if (op.hintAttr()) 1410 printSynchronizationHint(p, op, op.hintAttr()); 1411 p << ": " << op.address().getType() << ", " << op.value().getType(); 1412 return; 1413 } 1414 1415 /// Verifier for AtomicWriteOp 1416 static LogicalResult verifyAtomicWriteOp(AtomicWriteOp op) { 1417 if (op.memory_order()) { 1418 StringRef memoryOrder = op.memory_order().getValue(); 1419 if (memoryOrder.equals("acq_rel") || memoryOrder.equals("acquire")) 1420 return op.emitError( 1421 "memory-order must not be acq_rel or acquire for atomic writes"); 1422 } 1423 return verifySynchronizationHint(op, op.hint()); 1424 } 1425 1426 #define GET_OP_CLASSES 1427 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 1428