1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the OpenMP dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/IR/Attributes.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/IR/OperationSupport.h" 19 20 #include "llvm/ADT/BitVector.h" 21 #include "llvm/ADT/SmallString.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/ADT/StringRef.h" 24 #include "llvm/ADT/StringSwitch.h" 25 #include <cstddef> 26 27 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc" 28 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" 29 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc" 30 31 using namespace mlir; 32 using namespace mlir::omp; 33 34 namespace { 35 /// Model for pointer-like types that already provide a `getElementType` method. 
36 template <typename T> 37 struct PointerLikeModel 38 : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> { 39 Type getElementType(Type pointer) const { 40 return pointer.cast<T>().getElementType(); 41 } 42 }; 43 } // namespace 44 45 void OpenMPDialect::initialize() { 46 addOperations< 47 #define GET_OP_LIST 48 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 49 >(); 50 51 LLVM::LLVMPointerType::attachInterface< 52 PointerLikeModel<LLVM::LLVMPointerType>>(*getContext()); 53 MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext()); 54 } 55 56 //===----------------------------------------------------------------------===// 57 // ParallelOp 58 //===----------------------------------------------------------------------===// 59 60 void ParallelOp::build(OpBuilder &builder, OperationState &state, 61 ArrayRef<NamedAttribute> attributes) { 62 ParallelOp::build( 63 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, 64 /*default_val=*/nullptr, /*private_vars=*/ValueRange(), 65 /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(), 66 /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(), 67 /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr); 68 state.addAttributes(attributes); 69 } 70 71 //===----------------------------------------------------------------------===// 72 // Parser and printer for Operand and type list 73 //===----------------------------------------------------------------------===// 74 75 /// Parse a list of operands with types. 
76 /// 77 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)` 78 /// ssa-id-and-type-list ::= ssa-id-and-type | 79 /// ssa-id-and-type `,` ssa-id-and-type-list 80 /// ssa-id-and-type ::= ssa-id `:` type 81 static ParseResult 82 parseOperandAndTypeList(OpAsmParser &parser, 83 SmallVectorImpl<OpAsmParser::OperandType> &operands, 84 SmallVectorImpl<Type> &types) { 85 return parser.parseCommaSeparatedList( 86 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 87 OpAsmParser::OperandType operand; 88 Type type; 89 if (parser.parseOperand(operand) || parser.parseColonType(type)) 90 return failure(); 91 operands.push_back(operand); 92 types.push_back(type); 93 return success(); 94 }); 95 } 96 97 /// Print an operand and type list with parentheses 98 static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) { 99 p << "("; 100 llvm::interleaveComma( 101 operands, p, [&](const Value &v) { p << v << " : " << v.getType(); }); 102 p << ") "; 103 } 104 105 /// Print data variables corresponding to a data-sharing clause `name` 106 static void printDataVars(OpAsmPrinter &p, OperandRange operands, 107 StringRef name) { 108 if (!operands.empty()) { 109 p << name; 110 printOperandAndTypeList(p, operands); 111 } 112 } 113 114 //===----------------------------------------------------------------------===// 115 // Parser and printer for Allocate Clause 116 //===----------------------------------------------------------------------===// 117 118 /// Parse an allocate clause with allocators and a list of operands with types. 
119 /// 120 /// allocate ::= `allocate` `(` allocate-operand-list `)` 121 /// allocate-operand-list :: = allocate-operand | 122 /// allocator-operand `,` allocate-operand-list 123 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 124 /// ssa-id-and-type ::= ssa-id `:` type 125 static ParseResult parseAllocateAndAllocator( 126 OpAsmParser &parser, 127 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate, 128 SmallVectorImpl<Type> &typesAllocate, 129 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator, 130 SmallVectorImpl<Type> &typesAllocator) { 131 132 return parser.parseCommaSeparatedList( 133 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 134 OpAsmParser::OperandType operand; 135 Type type; 136 if (parser.parseOperand(operand) || parser.parseColonType(type)) 137 return failure(); 138 operandsAllocator.push_back(operand); 139 typesAllocator.push_back(type); 140 if (parser.parseArrow()) 141 return failure(); 142 if (parser.parseOperand(operand) || parser.parseColonType(type)) 143 return failure(); 144 145 operandsAllocate.push_back(operand); 146 typesAllocate.push_back(type); 147 return success(); 148 }); 149 } 150 151 /// Print allocate clause 152 static void printAllocateAndAllocator(OpAsmPrinter &p, 153 OperandRange varsAllocate, 154 OperandRange varsAllocator) { 155 p << "allocate("; 156 for (unsigned i = 0; i < varsAllocate.size(); ++i) { 157 std::string separator = i == varsAllocate.size() - 1 ? 
") " : ", "; 158 p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> "; 159 p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator; 160 } 161 } 162 163 static LogicalResult verifyParallelOp(ParallelOp op) { 164 if (op.allocate_vars().size() != op.allocators_vars().size()) 165 return op.emitError( 166 "expected equal sizes for allocate and allocator variables"); 167 return success(); 168 } 169 170 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) { 171 p << " "; 172 if (auto ifCond = op.if_expr_var()) 173 p << "if(" << ifCond << " : " << ifCond.getType() << ") "; 174 175 if (auto threads = op.num_threads_var()) 176 p << "num_threads(" << threads << " : " << threads.getType() << ") "; 177 178 printDataVars(p, op.private_vars(), "private"); 179 printDataVars(p, op.firstprivate_vars(), "firstprivate"); 180 printDataVars(p, op.shared_vars(), "shared"); 181 printDataVars(p, op.copyin_vars(), "copyin"); 182 183 if (!op.allocate_vars().empty()) 184 printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars()); 185 186 if (auto def = op.default_val()) 187 p << "default(" << def->drop_front(3) << ") "; 188 189 if (auto bind = op.proc_bind_val()) 190 p << "proc_bind(" << bind << ") "; 191 192 p << ' '; 193 p.printRegion(op.getRegion()); 194 } 195 196 //===----------------------------------------------------------------------===// 197 // Parser and printer for Linear Clause 198 //===----------------------------------------------------------------------===// 199 200 /// linear ::= `linear` `(` linear-list `)` 201 /// linear-list := linear-val | linear-val linear-list 202 /// linear-val := ssa-id-and-type `=` ssa-id-and-type 203 static ParseResult 204 parseLinearClause(OpAsmParser &parser, 205 SmallVectorImpl<OpAsmParser::OperandType> &vars, 206 SmallVectorImpl<Type> &types, 207 SmallVectorImpl<OpAsmParser::OperandType> &stepVars) { 208 if (parser.parseLParen()) 209 return failure(); 210 211 do { 212 OpAsmParser::OperandType 
var; 213 Type type; 214 OpAsmParser::OperandType stepVar; 215 if (parser.parseOperand(var) || parser.parseEqual() || 216 parser.parseOperand(stepVar) || parser.parseColonType(type)) 217 return failure(); 218 219 vars.push_back(var); 220 types.push_back(type); 221 stepVars.push_back(stepVar); 222 } while (succeeded(parser.parseOptionalComma())); 223 224 if (parser.parseRParen()) 225 return failure(); 226 227 return success(); 228 } 229 230 /// Print Linear Clause 231 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars, 232 OperandRange linearStepVars) { 233 size_t linearVarsSize = linearVars.size(); 234 p << "linear("; 235 for (unsigned i = 0; i < linearVarsSize; ++i) { 236 std::string separator = i == linearVarsSize - 1 ? ") " : ", "; 237 p << linearVars[i]; 238 if (linearStepVars.size() > i) 239 p << " = " << linearStepVars[i]; 240 p << " : " << linearVars[i].getType() << separator; 241 } 242 } 243 244 //===----------------------------------------------------------------------===// 245 // Parser and printer for Schedule Clause 246 //===----------------------------------------------------------------------===// 247 248 static ParseResult 249 verifyScheduleModifiers(OpAsmParser &parser, 250 SmallVectorImpl<SmallString<12>> &modifiers) { 251 if (modifiers.size() > 2) 252 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)"; 253 for (const auto &mod : modifiers) { 254 // Translate the string. If it has no value, then it was not a valid 255 // modifier! 256 auto symbol = symbolizeScheduleModifier(mod); 257 if (!symbol.hasValue()) 258 return parser.emitError(parser.getNameLoc()) 259 << " unknown modifier type: " << mod; 260 } 261 262 // If we have one modifier that is "simd", then stick a "none" modiifer in 263 // index 0. 
264 if (modifiers.size() == 1) { 265 if (symbolizeScheduleModifier(modifiers[0]) == 266 mlir::omp::ScheduleModifier::simd) { 267 modifiers.push_back(modifiers[0]); 268 modifiers[0] = 269 stringifyScheduleModifier(mlir::omp::ScheduleModifier::none); 270 } 271 } else if (modifiers.size() == 2) { 272 // If there are two modifier: 273 // First modifier should not be simd, second one should be simd 274 if (symbolizeScheduleModifier(modifiers[0]) == 275 mlir::omp::ScheduleModifier::simd || 276 symbolizeScheduleModifier(modifiers[1]) != 277 mlir::omp::ScheduleModifier::simd) 278 return parser.emitError(parser.getNameLoc()) 279 << " incorrect modifier order"; 280 } 281 return success(); 282 } 283 284 /// schedule ::= `schedule` `(` sched-list `)` 285 /// sched-list ::= sched-val | sched-val sched-list | 286 /// sched-val `,` sched-modifier 287 /// sched-val ::= sched-with-chunk | sched-wo-chunk 288 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 289 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided` 290 /// sched-wo-chunk ::= `auto` | `runtime` 291 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val 292 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none` 293 static ParseResult 294 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, 295 SmallVectorImpl<SmallString<12>> &modifiers, 296 Optional<OpAsmParser::OperandType> &chunkSize) { 297 if (parser.parseLParen()) 298 return failure(); 299 300 StringRef keyword; 301 if (parser.parseKeyword(&keyword)) 302 return failure(); 303 304 schedule = keyword; 305 if (keyword == "static" || keyword == "dynamic" || keyword == "guided") { 306 if (succeeded(parser.parseOptionalEqual())) { 307 chunkSize = OpAsmParser::OperandType{}; 308 if (parser.parseOperand(*chunkSize)) 309 return failure(); 310 } else { 311 chunkSize = llvm::NoneType::None; 312 } 313 } else if (keyword == "auto" || keyword == "runtime") { 314 chunkSize = llvm::NoneType::None; 315 } 
else { 316 return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; 317 } 318 319 // If there is a comma, we have one or more modifiers.. 320 while (succeeded(parser.parseOptionalComma())) { 321 StringRef mod; 322 if (parser.parseKeyword(&mod)) 323 return failure(); 324 modifiers.push_back(mod); 325 } 326 327 if (parser.parseRParen()) 328 return failure(); 329 330 if (verifyScheduleModifiers(parser, modifiers)) 331 return failure(); 332 333 return success(); 334 } 335 336 /// Print schedule clause 337 static void printScheduleClause(OpAsmPrinter &p, StringRef &sched, 338 llvm::Optional<StringRef> modifier, bool simd, 339 Value scheduleChunkVar) { 340 std::string schedLower = sched.lower(); 341 p << "schedule(" << schedLower; 342 if (scheduleChunkVar) 343 p << " = " << scheduleChunkVar; 344 if (modifier && modifier.hasValue()) 345 p << ", " << modifier; 346 if (simd) 347 p << ", simd"; 348 p << ") "; 349 } 350 351 //===----------------------------------------------------------------------===// 352 // Parser, printer and verifier for ReductionVarList 353 //===----------------------------------------------------------------------===// 354 355 /// reduction ::= `reduction` `(` reduction-entry-list `)` 356 /// reduction-entry-list ::= reduction-entry 357 /// | reduction-entry-list `,` reduction-entry 358 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type 359 static ParseResult 360 parseReductionVarList(OpAsmParser &parser, 361 SmallVectorImpl<SymbolRefAttr> &symbols, 362 SmallVectorImpl<OpAsmParser::OperandType> &operands, 363 SmallVectorImpl<Type> &types) { 364 if (failed(parser.parseLParen())) 365 return failure(); 366 367 do { 368 if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() || 369 parser.parseOperand(operands.emplace_back()) || 370 parser.parseColonType(types.emplace_back())) 371 return failure(); 372 } while (succeeded(parser.parseOptionalComma())); 373 return parser.parseRParen(); 374 } 375 376 /// Print Reduction 
clause 377 static void printReductionVarList(OpAsmPrinter &p, 378 Optional<ArrayAttr> reductions, 379 OperandRange reductionVars) { 380 p << "reduction("; 381 for (unsigned i = 0, e = reductions->size(); i < e; ++i) { 382 if (i != 0) 383 p << ", "; 384 p << (*reductions)[i] << " -> " << reductionVars[i] << " : " 385 << reductionVars[i].getType(); 386 } 387 p << ") "; 388 } 389 390 /// Verifies Reduction Clause 391 static LogicalResult verifyReductionVarList(Operation *op, 392 Optional<ArrayAttr> reductions, 393 OperandRange reductionVars) { 394 if (!reductionVars.empty()) { 395 if (!reductions || reductions->size() != reductionVars.size()) 396 return op->emitOpError() 397 << "expected as many reduction symbol references " 398 "as reduction variables"; 399 } else { 400 if (reductions) 401 return op->emitOpError() << "unexpected reduction symbol references"; 402 return success(); 403 } 404 405 DenseSet<Value> accumulators; 406 for (auto args : llvm::zip(reductionVars, *reductions)) { 407 Value accum = std::get<0>(args); 408 409 if (!accumulators.insert(accum).second) 410 return op->emitOpError() << "accumulator variable used more than once"; 411 412 Type varType = accum.getType().cast<PointerLikeType>(); 413 auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>(); 414 auto decl = 415 SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef); 416 if (!decl) 417 return op->emitOpError() << "expected symbol reference " << symbolRef 418 << " to point to a reduction declaration"; 419 420 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) 421 return op->emitOpError() 422 << "expected accumulator (" << varType 423 << ") to be the same type as reduction declaration (" 424 << decl.getAccumulatorType() << ")"; 425 } 426 427 return success(); 428 } 429 430 //===----------------------------------------------------------------------===// 431 // Parser, printer and verifier for Synchronization Hint (2.17.12) 432 
//===----------------------------------------------------------------------===// 433 434 /// Parses a Synchronization Hint clause. The value of hint is an integer 435 /// which is a combination of different hints from `omp_sync_hint_t`. 436 /// 437 /// hint-clause = `hint` `(` hint-value `)` 438 static ParseResult parseSynchronizationHint(OpAsmParser &parser, 439 IntegerAttr &hintAttr, 440 bool parseKeyword = true) { 441 if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) { 442 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0); 443 return success(); 444 } 445 446 if (failed(parser.parseLParen())) 447 return failure(); 448 StringRef hintKeyword; 449 int64_t hint = 0; 450 do { 451 if (failed(parser.parseKeyword(&hintKeyword))) 452 return failure(); 453 if (hintKeyword == "uncontended") 454 hint |= 1; 455 else if (hintKeyword == "contended") 456 hint |= 2; 457 else if (hintKeyword == "nonspeculative") 458 hint |= 4; 459 else if (hintKeyword == "speculative") 460 hint |= 8; 461 else 462 return parser.emitError(parser.getCurrentLocation()) 463 << hintKeyword << " is not a valid hint"; 464 } while (succeeded(parser.parseOptionalComma())); 465 if (failed(parser.parseRParen())) 466 return failure(); 467 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint); 468 return success(); 469 } 470 471 /// Prints a Synchronization Hint clause 472 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, 473 IntegerAttr hintAttr) { 474 int64_t hint = hintAttr.getInt(); 475 476 if (hint == 0) 477 return; 478 479 // Helper function to get n-th bit from the right end of `value` 480 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 481 482 bool uncontended = bitn(hint, 0); 483 bool contended = bitn(hint, 1); 484 bool nonspeculative = bitn(hint, 2); 485 bool speculative = bitn(hint, 3); 486 487 SmallVector<StringRef> hints; 488 if (uncontended) 489 hints.push_back("uncontended"); 490 if (contended) 491 
hints.push_back("contended"); 492 if (nonspeculative) 493 hints.push_back("nonspeculative"); 494 if (speculative) 495 hints.push_back("speculative"); 496 497 p << "hint("; 498 llvm::interleaveComma(hints, p); 499 p << ") "; 500 } 501 502 /// Verifies a synchronization hint clause 503 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) { 504 505 // Helper function to get n-th bit from the right end of `value` 506 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 507 508 bool uncontended = bitn(hint, 0); 509 bool contended = bitn(hint, 1); 510 bool nonspeculative = bitn(hint, 2); 511 bool speculative = bitn(hint, 3); 512 513 if (uncontended && contended) 514 return op->emitOpError() << "the hints omp_sync_hint_uncontended and " 515 "omp_sync_hint_contended cannot be combined"; 516 if (nonspeculative && speculative) 517 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and " 518 "omp_sync_hint_speculative cannot be combined."; 519 return success(); 520 } 521 522 enum ClauseType { 523 ifClause, 524 numThreadsClause, 525 privateClause, 526 firstprivateClause, 527 lastprivateClause, 528 sharedClause, 529 copyinClause, 530 allocateClause, 531 defaultClause, 532 procBindClause, 533 reductionClause, 534 nowaitClause, 535 linearClause, 536 scheduleClause, 537 collapseClause, 538 orderClause, 539 orderedClause, 540 memoryOrderClause, 541 hintClause, 542 COUNT 543 }; 544 545 //===----------------------------------------------------------------------===// 546 // Parser for Clause List 547 //===----------------------------------------------------------------------===// 548 549 /// Parse a list of clauses. The clauses can appear in any order, but their 550 /// operand segment indices are in the same order that they are passed in the 551 /// `clauses` list. 
/// The sizes of the parsed clauses' operand segments are appended to
/// `segments`, after any entries it already contains.
///
/// clause-list ::= clause clause-list | empty
/// clause ::= if | num-threads | private | firstprivate | lastprivate |
///            shared | copyin | allocate | default | proc-bind | reduction |
///            nowait | linear | schedule | collapse | order | ordered |
///            memory-order | hint | inclusive
/// if ::= `if` `(` ssa-id-and-type `)`
/// num-threads ::= `num_threads` `(` ssa-id-and-type `)`
/// private ::= `private` operand-and-type-list
/// firstprivate ::= `firstprivate` operand-and-type-list
/// lastprivate ::= `lastprivate` operand-and-type-list
/// shared ::= `shared` operand-and-type-list
/// copyin ::= `copyin` operand-and-type-list
/// allocate ::= `allocate` `(` allocate-operand-list `)`
/// default ::= `default` `(` (`private` | `firstprivate` | `shared` |
///             `none`) `)`
/// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
/// reduction ::= `reduction` `(` reduction-entry-list `)`
/// nowait ::= `nowait`
/// linear ::= `linear` `(` linear-list `)`
/// schedule ::= `schedule` `(` sched-list `)`
/// collapse ::= `collapse` `(` integer-literal `)`
/// order ::= `order` `(` `concurrent` `)`
/// ordered ::= `ordered` (`(` integer-literal `)`)?
/// memory-order ::= `memory_order` `(` bare-id `)`
/// hint ::= `hint` `(` hint-kind (`,` hint-kind)* `)`
/// inclusive ::= `inclusive`
///
/// Note that each clause can only appear once in the clause-list.
578 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result, 579 SmallVectorImpl<ClauseType> &clauses, 580 SmallVectorImpl<int> &segments) { 581 582 // Check done[clause] to see if it has been parsed already 583 llvm::BitVector done(ClauseType::COUNT, false); 584 585 // See pos[clause] to get position of clause in operand segments 586 SmallVector<int> pos(ClauseType::COUNT, -1); 587 588 // Stores the last parsed clause keyword 589 StringRef clauseKeyword; 590 StringRef opName = result.name.getStringRef(); 591 592 // Containers for storing operands, types and attributes for various clauses 593 std::pair<OpAsmParser::OperandType, Type> ifCond; 594 std::pair<OpAsmParser::OperandType, Type> numThreads; 595 596 SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates, 597 shareds, copyins; 598 SmallVector<Type> privateTypes, firstprivateTypes, lastprivateTypes, 599 sharedTypes, copyinTypes; 600 601 SmallVector<OpAsmParser::OperandType> allocates, allocators; 602 SmallVector<Type> allocateTypes, allocatorTypes; 603 604 SmallVector<SymbolRefAttr> reductionSymbols; 605 SmallVector<OpAsmParser::OperandType> reductionVars; 606 SmallVector<Type> reductionVarTypes; 607 608 SmallVector<OpAsmParser::OperandType> linears; 609 SmallVector<Type> linearTypes; 610 SmallVector<OpAsmParser::OperandType> linearSteps; 611 612 SmallString<8> schedule; 613 SmallVector<SmallString<12>> modifiers; 614 Optional<OpAsmParser::OperandType> scheduleChunkSize; 615 616 // Compute the position of clauses in operand segments 617 int currPos = 0; 618 for (ClauseType clause : clauses) { 619 620 // Skip the following clauses - they do not take any position in operand 621 // segments 622 if (clause == defaultClause || clause == procBindClause || 623 clause == nowaitClause || clause == collapseClause || 624 clause == orderClause || clause == orderedClause) 625 continue; 626 627 pos[clause] = currPos++; 628 629 // For the following clauses, two positions are reserved 
in the operand 630 // segments 631 if (clause == allocateClause || clause == linearClause) 632 currPos++; 633 } 634 635 SmallVector<int> clauseSegments(currPos); 636 637 // Helper function to check if a clause is allowed/repeated or not 638 auto checkAllowed = [&](ClauseType clause, 639 bool allowRepeat = false) -> ParseResult { 640 if (!llvm::is_contained(clauses, clause)) 641 return parser.emitError(parser.getCurrentLocation()) 642 << clauseKeyword << " is not a valid clause for the " << opName 643 << " operation"; 644 if (done[clause] && !allowRepeat) 645 return parser.emitError(parser.getCurrentLocation()) 646 << "at most one " << clauseKeyword << " clause can appear on the " 647 << opName << " operation"; 648 done[clause] = true; 649 return success(); 650 }; 651 652 while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) { 653 if (clauseKeyword == "if") { 654 if (checkAllowed(ifClause) || parser.parseLParen() || 655 parser.parseOperand(ifCond.first) || 656 parser.parseColonType(ifCond.second) || parser.parseRParen()) 657 return failure(); 658 clauseSegments[pos[ifClause]] = 1; 659 } else if (clauseKeyword == "num_threads") { 660 if (checkAllowed(numThreadsClause) || parser.parseLParen() || 661 parser.parseOperand(numThreads.first) || 662 parser.parseColonType(numThreads.second) || parser.parseRParen()) 663 return failure(); 664 clauseSegments[pos[numThreadsClause]] = 1; 665 } else if (clauseKeyword == "private") { 666 if (checkAllowed(privateClause) || 667 parseOperandAndTypeList(parser, privates, privateTypes)) 668 return failure(); 669 clauseSegments[pos[privateClause]] = privates.size(); 670 } else if (clauseKeyword == "firstprivate") { 671 if (checkAllowed(firstprivateClause) || 672 parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) 673 return failure(); 674 clauseSegments[pos[firstprivateClause]] = firstprivates.size(); 675 } else if (clauseKeyword == "lastprivate") { 676 if (checkAllowed(lastprivateClause) || 677 
parseOperandAndTypeList(parser, lastprivates, lastprivateTypes)) 678 return failure(); 679 clauseSegments[pos[lastprivateClause]] = lastprivates.size(); 680 } else if (clauseKeyword == "shared") { 681 if (checkAllowed(sharedClause) || 682 parseOperandAndTypeList(parser, shareds, sharedTypes)) 683 return failure(); 684 clauseSegments[pos[sharedClause]] = shareds.size(); 685 } else if (clauseKeyword == "copyin") { 686 if (checkAllowed(copyinClause) || 687 parseOperandAndTypeList(parser, copyins, copyinTypes)) 688 return failure(); 689 clauseSegments[pos[copyinClause]] = copyins.size(); 690 } else if (clauseKeyword == "allocate") { 691 if (checkAllowed(allocateClause) || 692 parseAllocateAndAllocator(parser, allocates, allocateTypes, 693 allocators, allocatorTypes)) 694 return failure(); 695 clauseSegments[pos[allocateClause]] = allocates.size(); 696 clauseSegments[pos[allocateClause] + 1] = allocators.size(); 697 } else if (clauseKeyword == "default") { 698 StringRef defval; 699 if (checkAllowed(defaultClause) || parser.parseLParen() || 700 parser.parseKeyword(&defval) || parser.parseRParen()) 701 return failure(); 702 // The def prefix is required for the attribute as "private" is a keyword 703 // in C++. 
704 auto attr = parser.getBuilder().getStringAttr("def" + defval); 705 result.addAttribute("default_val", attr); 706 } else if (clauseKeyword == "proc_bind") { 707 StringRef bind; 708 if (checkAllowed(procBindClause) || parser.parseLParen() || 709 parser.parseKeyword(&bind) || parser.parseRParen()) 710 return failure(); 711 auto attr = parser.getBuilder().getStringAttr(bind); 712 result.addAttribute("proc_bind_val", attr); 713 } else if (clauseKeyword == "reduction") { 714 if (checkAllowed(reductionClause) || 715 parseReductionVarList(parser, reductionSymbols, reductionVars, 716 reductionVarTypes)) 717 return failure(); 718 clauseSegments[pos[reductionClause]] = reductionVars.size(); 719 } else if (clauseKeyword == "nowait") { 720 if (checkAllowed(nowaitClause)) 721 return failure(); 722 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 723 result.addAttribute("nowait", attr); 724 } else if (clauseKeyword == "linear") { 725 if (checkAllowed(linearClause) || 726 parseLinearClause(parser, linears, linearTypes, linearSteps)) 727 return failure(); 728 clauseSegments[pos[linearClause]] = linears.size(); 729 clauseSegments[pos[linearClause] + 1] = linearSteps.size(); 730 } else if (clauseKeyword == "schedule") { 731 if (checkAllowed(scheduleClause) || 732 parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize)) 733 return failure(); 734 if (scheduleChunkSize) { 735 clauseSegments[pos[scheduleClause]] = 1; 736 } 737 } else if (clauseKeyword == "collapse") { 738 auto type = parser.getBuilder().getI64Type(); 739 mlir::IntegerAttr attr; 740 if (checkAllowed(collapseClause) || parser.parseLParen() || 741 parser.parseAttribute(attr, type) || parser.parseRParen()) 742 return failure(); 743 result.addAttribute("collapse_val", attr); 744 } else if (clauseKeyword == "ordered") { 745 mlir::IntegerAttr attr; 746 if (checkAllowed(orderedClause)) 747 return failure(); 748 if (succeeded(parser.parseOptionalLParen())) { 749 auto type = 
parser.getBuilder().getI64Type(); 750 if (parser.parseAttribute(attr, type) || parser.parseRParen()) 751 return failure(); 752 } else { 753 // Use 0 to represent no ordered parameter was specified 754 attr = parser.getBuilder().getI64IntegerAttr(0); 755 } 756 result.addAttribute("ordered_val", attr); 757 } else if (clauseKeyword == "order") { 758 StringRef order; 759 if (checkAllowed(orderClause) || parser.parseLParen() || 760 parser.parseKeyword(&order) || parser.parseRParen()) 761 return failure(); 762 auto attr = parser.getBuilder().getStringAttr(order); 763 result.addAttribute("order_val", attr); 764 } else if (clauseKeyword == "memory_order") { 765 StringRef memoryOrder; 766 if (checkAllowed(memoryOrderClause) || parser.parseLParen() || 767 parser.parseKeyword(&memoryOrder) || parser.parseRParen()) 768 return failure(); 769 result.addAttribute("memory_order", 770 parser.getBuilder().getStringAttr(memoryOrder)); 771 } else if (clauseKeyword == "hint") { 772 IntegerAttr hint; 773 if (checkAllowed(hintClause) || 774 parseSynchronizationHint(parser, hint, false)) 775 return failure(); 776 result.addAttribute("hint", hint); 777 } else { 778 return parser.emitError(parser.getNameLoc()) 779 << clauseKeyword << " is not a valid clause"; 780 } 781 } 782 783 // Add if parameter. 784 if (done[ifClause] && clauseSegments[pos[ifClause]] && 785 failed( 786 parser.resolveOperand(ifCond.first, ifCond.second, result.operands))) 787 return failure(); 788 789 // Add num_threads parameter. 790 if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] && 791 failed(parser.resolveOperand(numThreads.first, numThreads.second, 792 result.operands))) 793 return failure(); 794 795 // Add private parameters. 796 if (done[privateClause] && clauseSegments[pos[privateClause]] && 797 failed(parser.resolveOperands(privates, privateTypes, 798 privates[0].location, result.operands))) 799 return failure(); 800 801 // Add firstprivate parameters. 
802 if (done[firstprivateClause] && clauseSegments[pos[firstprivateClause]] && 803 failed(parser.resolveOperands(firstprivates, firstprivateTypes, 804 firstprivates[0].location, 805 result.operands))) 806 return failure(); 807 808 // Add lastprivate parameters. 809 if (done[lastprivateClause] && clauseSegments[pos[lastprivateClause]] && 810 failed(parser.resolveOperands(lastprivates, lastprivateTypes, 811 lastprivates[0].location, result.operands))) 812 return failure(); 813 814 // Add shared parameters. 815 if (done[sharedClause] && clauseSegments[pos[sharedClause]] && 816 failed(parser.resolveOperands(shareds, sharedTypes, shareds[0].location, 817 result.operands))) 818 return failure(); 819 820 // Add copyin parameters. 821 if (done[copyinClause] && clauseSegments[pos[copyinClause]] && 822 failed(parser.resolveOperands(copyins, copyinTypes, copyins[0].location, 823 result.operands))) 824 return failure(); 825 826 // Add allocate parameters. 827 if (done[allocateClause] && clauseSegments[pos[allocateClause]] && 828 failed(parser.resolveOperands(allocates, allocateTypes, 829 allocates[0].location, result.operands))) 830 return failure(); 831 832 // Add allocator parameters. 
833 if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] && 834 failed(parser.resolveOperands(allocators, allocatorTypes, 835 allocators[0].location, result.operands))) 836 return failure(); 837 838 // Add reduction parameters and symbols 839 if (done[reductionClause] && clauseSegments[pos[reductionClause]]) { 840 if (failed(parser.resolveOperands(reductionVars, reductionVarTypes, 841 parser.getNameLoc(), result.operands))) 842 return failure(); 843 844 SmallVector<Attribute> reductions(reductionSymbols.begin(), 845 reductionSymbols.end()); 846 result.addAttribute("reductions", 847 parser.getBuilder().getArrayAttr(reductions)); 848 } 849 850 // Add linear parameters 851 if (done[linearClause] && clauseSegments[pos[linearClause]]) { 852 auto linearStepType = parser.getBuilder().getI32Type(); 853 SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType); 854 if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location, 855 result.operands)) || 856 failed(parser.resolveOperands(linearSteps, linearStepTypes, 857 linearSteps[0].location, 858 result.operands))) 859 return failure(); 860 } 861 862 // Add schedule parameters 863 if (done[scheduleClause] && !schedule.empty()) { 864 schedule[0] = llvm::toUpper(schedule[0]); 865 auto attr = parser.getBuilder().getStringAttr(schedule); 866 result.addAttribute("schedule_val", attr); 867 if (!modifiers.empty()) { 868 auto mod = parser.getBuilder().getStringAttr(modifiers[0]); 869 result.addAttribute("schedule_modifier", mod); 870 // Only SIMD attribute is allowed here! 
871 if (modifiers.size() > 1) { 872 assert(symbolizeScheduleModifier(modifiers[1]) == 873 mlir::omp::ScheduleModifier::simd); 874 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 875 result.addAttribute("simd_modifier", attr); 876 } 877 } 878 if (scheduleChunkSize) { 879 auto chunkSizeType = parser.getBuilder().getI32Type(); 880 parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands); 881 } 882 } 883 884 segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end()); 885 886 return success(); 887 } 888 889 /// Parses a parallel operation. 890 /// 891 /// operation ::= `omp.parallel` clause-list 892 /// clause-list ::= clause | clause clause-list 893 /// clause ::= if | num-threads | private | firstprivate | shared | copyin | 894 /// allocate | default | proc-bind 895 /// 896 static ParseResult parseParallelOp(OpAsmParser &parser, 897 OperationState &result) { 898 SmallVector<ClauseType> clauses = { 899 ifClause, numThreadsClause, privateClause, 900 firstprivateClause, sharedClause, copyinClause, 901 allocateClause, defaultClause, procBindClause}; 902 903 SmallVector<int> segments; 904 905 if (failed(parseClauses(parser, result, clauses, segments))) 906 return failure(); 907 908 result.addAttribute("operand_segment_sizes", 909 parser.getBuilder().getI32VectorAttr(segments)); 910 911 Region *body = result.addRegion(); 912 SmallVector<OpAsmParser::OperandType> regionArgs; 913 SmallVector<Type> regionArgTypes; 914 if (parser.parseRegion(*body, regionArgs, regionArgTypes)) 915 return failure(); 916 return success(); 917 } 918 919 //===----------------------------------------------------------------------===// 920 // Parser, printer and verifier for SectionsOp 921 //===----------------------------------------------------------------------===// 922 923 /// Parses an OpenMP Sections operation 924 /// 925 /// sections ::= `omp.sections` clause-list 926 /// clause-list ::= clause clause-list | empty 927 /// clause ::= private | 
firstprivate | lastprivate | reduction | allocate | 928 /// nowait 929 static ParseResult parseSectionsOp(OpAsmParser &parser, 930 OperationState &result) { 931 932 SmallVector<ClauseType> clauses = {privateClause, firstprivateClause, 933 lastprivateClause, reductionClause, 934 allocateClause, nowaitClause}; 935 936 SmallVector<int> segments; 937 938 if (failed(parseClauses(parser, result, clauses, segments))) 939 return failure(); 940 941 result.addAttribute("operand_segment_sizes", 942 parser.getBuilder().getI32VectorAttr(segments)); 943 944 // Now parse the body. 945 Region *body = result.addRegion(); 946 if (parser.parseRegion(*body)) 947 return failure(); 948 return success(); 949 } 950 951 static void printSectionsOp(OpAsmPrinter &p, SectionsOp op) { 952 p << " "; 953 printDataVars(p, op.private_vars(), "private"); 954 printDataVars(p, op.firstprivate_vars(), "firstprivate"); 955 printDataVars(p, op.lastprivate_vars(), "lastprivate"); 956 957 if (!op.reduction_vars().empty()) 958 printReductionVarList(p, op.reductions(), op.reduction_vars()); 959 960 if (!op.allocate_vars().empty()) 961 printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars()); 962 963 if (op.nowait()) 964 p << "nowait"; 965 966 p << ' '; 967 p.printRegion(op.region()); 968 } 969 970 static LogicalResult verifySectionsOp(SectionsOp op) { 971 972 // A list item may not appear in more than one clause on the same directive, 973 // except that it may be specified in both firstprivate and lastprivate 974 // clauses. 
975 for (auto var : op.private_vars()) { 976 if (llvm::is_contained(op.firstprivate_vars(), var)) 977 return op.emitOpError() 978 << "operand used in both private and firstprivate clauses"; 979 if (llvm::is_contained(op.lastprivate_vars(), var)) 980 return op.emitOpError() 981 << "operand used in both private and lastprivate clauses"; 982 } 983 984 if (op.allocate_vars().size() != op.allocators_vars().size()) 985 return op.emitError( 986 "expected equal sizes for allocate and allocator variables"); 987 988 for (auto &inst : *op.region().begin()) { 989 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) 990 op.emitOpError() 991 << "expected omp.section op or terminator op inside region"; 992 } 993 994 return verifyReductionVarList(op, op.reductions(), op.reduction_vars()); 995 } 996 997 /// Parses an OpenMP Workshare Loop operation 998 /// 999 /// wsloop ::= `omp.wsloop` loop-control clause-list 1000 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds 1001 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps 1002 /// steps := `step` `(`ssa-id-list`)` 1003 /// clause-list ::= clause clause-list | empty 1004 /// clause ::= private | firstprivate | lastprivate | linear | schedule | 1005 // collapse | nowait | ordered | order | reduction 1006 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) { 1007 1008 // Parse an opening `(` followed by induction variables followed by `)` 1009 SmallVector<OpAsmParser::OperandType> ivs; 1010 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, 1011 OpAsmParser::Delimiter::Paren)) 1012 return failure(); 1013 1014 int numIVs = static_cast<int>(ivs.size()); 1015 Type loopVarType; 1016 if (parser.parseColonType(loopVarType)) 1017 return failure(); 1018 1019 // Parse loop bounds. 
1020 SmallVector<OpAsmParser::OperandType> lower; 1021 if (parser.parseEqual() || 1022 parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) || 1023 parser.resolveOperands(lower, loopVarType, result.operands)) 1024 return failure(); 1025 1026 SmallVector<OpAsmParser::OperandType> upper; 1027 if (parser.parseKeyword("to") || 1028 parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) || 1029 parser.resolveOperands(upper, loopVarType, result.operands)) 1030 return failure(); 1031 1032 if (succeeded(parser.parseOptionalKeyword("inclusive"))) { 1033 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 1034 result.addAttribute("inclusive", attr); 1035 } 1036 1037 // Parse step values. 1038 SmallVector<OpAsmParser::OperandType> steps; 1039 if (parser.parseKeyword("step") || 1040 parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) || 1041 parser.resolveOperands(steps, loopVarType, result.operands)) 1042 return failure(); 1043 1044 SmallVector<ClauseType> clauses = { 1045 privateClause, firstprivateClause, lastprivateClause, linearClause, 1046 reductionClause, collapseClause, orderClause, orderedClause, 1047 nowaitClause, scheduleClause}; 1048 SmallVector<int> segments{numIVs, numIVs, numIVs}; 1049 if (failed(parseClauses(parser, result, clauses, segments))) 1050 return failure(); 1051 1052 result.addAttribute("operand_segment_sizes", 1053 parser.getBuilder().getI32VectorAttr(segments)); 1054 1055 // Now parse the body. 
1056 Region *body = result.addRegion(); 1057 SmallVector<Type> ivTypes(numIVs, loopVarType); 1058 SmallVector<OpAsmParser::OperandType> blockArgs(ivs); 1059 if (parser.parseRegion(*body, blockArgs, ivTypes)) 1060 return failure(); 1061 return success(); 1062 } 1063 1064 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) { 1065 auto args = op.getRegion().front().getArguments(); 1066 p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound() 1067 << ") to (" << op.upperBound() << ") "; 1068 if (op.inclusive()) { 1069 p << "inclusive "; 1070 } 1071 p << "step (" << op.step() << ") "; 1072 1073 printDataVars(p, op.private_vars(), "private"); 1074 printDataVars(p, op.firstprivate_vars(), "firstprivate"); 1075 printDataVars(p, op.lastprivate_vars(), "lastprivate"); 1076 1077 if (!op.linear_vars().empty()) 1078 printLinearClause(p, op.linear_vars(), op.linear_step_vars()); 1079 1080 if (auto sched = op.schedule_val()) 1081 printScheduleClause(p, sched.getValue(), op.schedule_modifier(), 1082 op.simd_modifier(), op.schedule_chunk_var()); 1083 1084 if (auto collapse = op.collapse_val()) 1085 p << "collapse(" << collapse << ") "; 1086 1087 if (op.nowait()) 1088 p << "nowait "; 1089 1090 if (auto ordered = op.ordered_val()) 1091 p << "ordered(" << ordered << ") "; 1092 1093 if (auto order = op.order_val()) 1094 p << "order(" << order << ") "; 1095 1096 if (!op.reduction_vars().empty()) 1097 printReductionVarList(p, op.reductions(), op.reduction_vars()); 1098 1099 p << ' '; 1100 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 1101 } 1102 1103 //===----------------------------------------------------------------------===// 1104 // ReductionOp 1105 //===----------------------------------------------------------------------===// 1106 1107 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, 1108 Region ®ion) { 1109 if (parser.parseOptionalKeyword("atomic")) 1110 return success(); 1111 return parser.parseRegion(region); 1112 } 1113 
1114 static void printAtomicReductionRegion(OpAsmPrinter &printer, 1115 ReductionDeclareOp op, Region ®ion) { 1116 if (region.empty()) 1117 return; 1118 printer << "atomic "; 1119 printer.printRegion(region); 1120 } 1121 1122 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) { 1123 if (op.initializerRegion().empty()) 1124 return op.emitOpError() << "expects non-empty initializer region"; 1125 Block &initializerEntryBlock = op.initializerRegion().front(); 1126 if (initializerEntryBlock.getNumArguments() != 1 || 1127 initializerEntryBlock.getArgument(0).getType() != op.type()) { 1128 return op.emitOpError() << "expects initializer region with one argument " 1129 "of the reduction type"; 1130 } 1131 1132 for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) { 1133 if (yieldOp.results().size() != 1 || 1134 yieldOp.results().getTypes()[0] != op.type()) 1135 return op.emitOpError() << "expects initializer region to yield a value " 1136 "of the reduction type"; 1137 } 1138 1139 if (op.reductionRegion().empty()) 1140 return op.emitOpError() << "expects non-empty reduction region"; 1141 Block &reductionEntryBlock = op.reductionRegion().front(); 1142 if (reductionEntryBlock.getNumArguments() != 2 || 1143 reductionEntryBlock.getArgumentTypes()[0] != 1144 reductionEntryBlock.getArgumentTypes()[1] || 1145 reductionEntryBlock.getArgumentTypes()[0] != op.type()) 1146 return op.emitOpError() << "expects reduction region with two arguments of " 1147 "the reduction type"; 1148 for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) { 1149 if (yieldOp.results().size() != 1 || 1150 yieldOp.results().getTypes()[0] != op.type()) 1151 return op.emitOpError() << "expects reduction region to yield a value " 1152 "of the reduction type"; 1153 } 1154 1155 if (op.atomicReductionRegion().empty()) 1156 return success(); 1157 1158 Block &atomicReductionEntryBlock = op.atomicReductionRegion().front(); 1159 if (atomicReductionEntryBlock.getNumArguments() != 2 
|| 1160 atomicReductionEntryBlock.getArgumentTypes()[0] != 1161 atomicReductionEntryBlock.getArgumentTypes()[1]) 1162 return op.emitOpError() << "expects atomic reduction region with two " 1163 "arguments of the same type"; 1164 auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0] 1165 .dyn_cast<PointerLikeType>(); 1166 if (!ptrType || ptrType.getElementType() != op.type()) 1167 return op.emitOpError() << "expects atomic reduction region arguments to " 1168 "be accumulators containing the reduction type"; 1169 return success(); 1170 } 1171 1172 static LogicalResult verifyReductionOp(ReductionOp op) { 1173 // TODO: generalize this to an op interface when there is more than one op 1174 // that supports reductions. 1175 auto container = op->getParentOfType<WsLoopOp>(); 1176 for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) 1177 if (container.reduction_vars()[i] == op.accumulator()) 1178 return success(); 1179 1180 return op.emitOpError() << "the accumulator is not used by the parent"; 1181 } 1182 1183 //===----------------------------------------------------------------------===// 1184 // WsLoopOp 1185 //===----------------------------------------------------------------------===// 1186 1187 void WsLoopOp::build(OpBuilder &builder, OperationState &state, 1188 ValueRange lowerBound, ValueRange upperBound, 1189 ValueRange step, ArrayRef<NamedAttribute> attributes) { 1190 build(builder, state, TypeRange(), lowerBound, upperBound, step, 1191 /*privateVars=*/ValueRange(), 1192 /*firstprivateVars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(), 1193 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), 1194 /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr, 1195 /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr, 1196 /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr, 1197 /*inclusive=*/nullptr, /*buildBody=*/false); 1198 state.addAttributes(attributes); 1199 } 1200 1201 void WsLoopOp::build(OpBuilder 
&, OperationState &state, TypeRange resultTypes, 1202 ValueRange operands, ArrayRef<NamedAttribute> attributes) { 1203 state.addOperands(operands); 1204 state.addAttributes(attributes); 1205 (void)state.addRegion(); 1206 assert(resultTypes.empty() && "mismatched number of return types"); 1207 state.addTypes(resultTypes); 1208 } 1209 1210 void WsLoopOp::build(OpBuilder &builder, OperationState &result, 1211 TypeRange typeRange, ValueRange lowerBounds, 1212 ValueRange upperBounds, ValueRange steps, 1213 ValueRange privateVars, ValueRange firstprivateVars, 1214 ValueRange lastprivateVars, ValueRange linearVars, 1215 ValueRange linearStepVars, ValueRange reductionVars, 1216 StringAttr scheduleVal, Value scheduleChunkVar, 1217 IntegerAttr collapseVal, UnitAttr nowait, 1218 IntegerAttr orderedVal, StringAttr orderVal, 1219 UnitAttr inclusive, bool buildBody) { 1220 result.addOperands(lowerBounds); 1221 result.addOperands(upperBounds); 1222 result.addOperands(steps); 1223 result.addOperands(privateVars); 1224 result.addOperands(firstprivateVars); 1225 result.addOperands(linearVars); 1226 result.addOperands(linearStepVars); 1227 if (scheduleChunkVar) 1228 result.addOperands(scheduleChunkVar); 1229 1230 if (scheduleVal) 1231 result.addAttribute("schedule_val", scheduleVal); 1232 if (collapseVal) 1233 result.addAttribute("collapse_val", collapseVal); 1234 if (nowait) 1235 result.addAttribute("nowait", nowait); 1236 if (orderedVal) 1237 result.addAttribute("ordered_val", orderedVal); 1238 if (orderVal) 1239 result.addAttribute("order", orderVal); 1240 if (inclusive) 1241 result.addAttribute("inclusive", inclusive); 1242 result.addAttribute( 1243 WsLoopOp::getOperandSegmentSizeAttr(), 1244 builder.getI32VectorAttr( 1245 {static_cast<int32_t>(lowerBounds.size()), 1246 static_cast<int32_t>(upperBounds.size()), 1247 static_cast<int32_t>(steps.size()), 1248 static_cast<int32_t>(privateVars.size()), 1249 static_cast<int32_t>(firstprivateVars.size()), 1250 
static_cast<int32_t>(lastprivateVars.size()), 1251 static_cast<int32_t>(linearVars.size()), 1252 static_cast<int32_t>(linearStepVars.size()), 1253 static_cast<int32_t>(reductionVars.size()), 1254 static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)})); 1255 1256 Region *bodyRegion = result.addRegion(); 1257 if (buildBody) { 1258 OpBuilder::InsertionGuard guard(builder); 1259 unsigned numIVs = steps.size(); 1260 SmallVector<Type, 8> argTypes(numIVs, steps.getType().front()); 1261 builder.createBlock(bodyRegion, {}, argTypes); 1262 } 1263 } 1264 1265 static LogicalResult verifyWsLoopOp(WsLoopOp op) { 1266 return verifyReductionVarList(op, op.reductions(), op.reduction_vars()); 1267 } 1268 1269 //===----------------------------------------------------------------------===// 1270 // Verifier for critical construct (2.17.1) 1271 //===----------------------------------------------------------------------===// 1272 1273 static LogicalResult verifyCriticalDeclareOp(CriticalDeclareOp op) { 1274 return verifySynchronizationHint(op, op.hint()); 1275 } 1276 1277 static LogicalResult verifyCriticalOp(CriticalOp op) { 1278 1279 if (op.nameAttr()) { 1280 auto symbolRef = op.nameAttr().cast<SymbolRefAttr>(); 1281 auto decl = 1282 SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef); 1283 if (!decl) { 1284 return op.emitOpError() << "expected symbol reference " << symbolRef 1285 << " to point to a critical declaration"; 1286 } 1287 } 1288 1289 return success(); 1290 } 1291 1292 //===----------------------------------------------------------------------===// 1293 // Verifier for ordered construct 1294 //===----------------------------------------------------------------------===// 1295 1296 static LogicalResult verifyOrderedOp(OrderedOp op) { 1297 auto container = op->getParentOfType<WsLoopOp>(); 1298 if (!container || !container.ordered_valAttr() || 1299 container.ordered_valAttr().getInt() == 0) 1300 return op.emitOpError() << "ordered depend directive must 
be closely " 1301 << "nested inside a worksharing-loop with ordered " 1302 << "clause with parameter present"; 1303 1304 if (container.ordered_valAttr().getInt() != 1305 (int64_t)op.num_loops_val().getValue()) 1306 return op.emitOpError() << "number of variables in depend clause does not " 1307 << "match number of iteration variables in the " 1308 << "doacross loop"; 1309 1310 return success(); 1311 } 1312 1313 static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) { 1314 // TODO: The code generation for ordered simd directive is not supported yet. 1315 if (op.simd()) 1316 return failure(); 1317 1318 if (auto container = op->getParentOfType<WsLoopOp>()) { 1319 if (!container.ordered_valAttr() || 1320 container.ordered_valAttr().getInt() != 0) 1321 return op.emitOpError() << "ordered region must be closely nested inside " 1322 << "a worksharing-loop region with an ordered " 1323 << "clause without parameter present"; 1324 } 1325 1326 return success(); 1327 } 1328 1329 //===----------------------------------------------------------------------===// 1330 // AtomicReadOp 1331 //===----------------------------------------------------------------------===// 1332 1333 /// Parser for AtomicReadOp 1334 /// 1335 /// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type 1336 /// address ::= operand `:` type 1337 static ParseResult parseAtomicReadOp(OpAsmParser &parser, 1338 OperationState &result) { 1339 OpAsmParser::OperandType x, v; 1340 Type addressType; 1341 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1342 SmallVector<int> segments; 1343 1344 if (parser.parseOperand(v) || parser.parseEqual() || parser.parseOperand(x) || 1345 parseClauses(parser, result, clauses, segments) || 1346 parser.parseColonType(addressType) || 1347 parser.resolveOperand(x, addressType, result.operands) || 1348 parser.resolveOperand(v, addressType, result.operands)) 1349 return failure(); 1350 1351 return success(); 1352 } 1353 1354 /// Printer 
for AtomicReadOp 1355 static void printAtomicReadOp(OpAsmPrinter &p, AtomicReadOp op) { 1356 p << " " << op.v() << " = " << op.x() << " "; 1357 if (op.memory_order()) 1358 p << "memory_order(" << op.memory_order().getValue() << ") "; 1359 if (op.hintAttr()) 1360 printSynchronizationHint(p << " ", op, op.hintAttr()); 1361 p << ": " << op.x().getType(); 1362 } 1363 1364 /// Verifier for AtomicReadOp 1365 static LogicalResult verifyAtomicReadOp(AtomicReadOp op) { 1366 if (op.memory_order()) { 1367 StringRef memOrder = op.memory_order().getValue(); 1368 if (memOrder.equals("acq_rel") || memOrder.equals("release")) 1369 return op.emitError( 1370 "memory-order must not be acq_rel or release for atomic reads"); 1371 } 1372 if (op.x() == op.v()) 1373 return op.emitError( 1374 "read and write must not be to the same location for atomic reads"); 1375 return verifySynchronizationHint(op, op.hint()); 1376 } 1377 1378 //===----------------------------------------------------------------------===// 1379 // AtomicWriteOp 1380 //===----------------------------------------------------------------------===// 1381 1382 /// Parser for AtomicWriteOp 1383 /// 1384 /// operation ::= `omp.atomic.write` atomic-clause-list operands 1385 /// operands ::= address `,` value 1386 /// address ::= operand `:` type 1387 /// value ::= operand `:` type 1388 static ParseResult parseAtomicWriteOp(OpAsmParser &parser, 1389 OperationState &result) { 1390 OpAsmParser::OperandType address, value; 1391 Type addrType, valueType; 1392 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1393 SmallVector<int> segments; 1394 1395 if (parser.parseOperand(address) || parser.parseEqual() || 1396 parser.parseOperand(value) || 1397 parseClauses(parser, result, clauses, segments) || 1398 parser.parseColonType(addrType) || parser.parseComma() || 1399 parser.parseType(valueType) || 1400 parser.resolveOperand(address, addrType, result.operands) || 1401 parser.resolveOperand(value, valueType, 
result.operands)) 1402 return failure(); 1403 return success(); 1404 } 1405 1406 /// Printer for AtomicWriteOp 1407 static void printAtomicWriteOp(OpAsmPrinter &p, AtomicWriteOp op) { 1408 p << " " << op.address() << " = " << op.value() << " "; 1409 if (op.memory_order()) 1410 p << "memory_order(" << op.memory_order() << ") "; 1411 if (op.hintAttr()) 1412 printSynchronizationHint(p, op, op.hintAttr()); 1413 p << ": " << op.address().getType() << ", " << op.value().getType(); 1414 } 1415 1416 /// Verifier for AtomicWriteOp 1417 static LogicalResult verifyAtomicWriteOp(AtomicWriteOp op) { 1418 if (op.memory_order()) { 1419 StringRef memoryOrder = op.memory_order().getValue(); 1420 if (memoryOrder.equals("acq_rel") || memoryOrder.equals("acquire")) 1421 return op.emitError( 1422 "memory-order must not be acq_rel or acquire for atomic writes"); 1423 } 1424 return verifySynchronizationHint(op, op.hint()); 1425 } 1426 1427 //===----------------------------------------------------------------------===// 1428 // AtomicUpdateOp 1429 //===----------------------------------------------------------------------===// 1430 1431 /// Parser for AtomicUpdateOp 1432 /// 1433 /// operation ::= `omp.atomic.update` atomic-clause-list region 1434 static ParseResult parseAtomicUpdateOp(OpAsmParser &parser, 1435 OperationState &result) { 1436 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1437 SmallVector<int> segments; 1438 OpAsmParser::OperandType x, y, z; 1439 Type xType, exprType; 1440 StringRef binOp; 1441 1442 // x = y `op` z : xtype, exprtype 1443 if (parser.parseOperand(x) || parser.parseEqual() || parser.parseOperand(y) || 1444 parser.parseKeyword(&binOp) || parser.parseOperand(z) || 1445 parseClauses(parser, result, clauses, segments) || parser.parseColon() || 1446 parser.parseType(xType) || parser.parseComma() || 1447 parser.parseType(exprType) || 1448 parser.resolveOperand(x, xType, result.operands)) { 1449 return failure(); 1450 } 1451 1452 auto binOpEnum = 
AtomicBinOpKindToEnum(binOp.upper()); 1453 if (!binOpEnum) 1454 return parser.emitError(parser.getNameLoc()) 1455 << "invalid atomic bin op in atomic update\n"; 1456 auto attr = 1457 parser.getBuilder().getI64IntegerAttr((int64_t)binOpEnum.getValue()); 1458 result.addAttribute("binop", attr); 1459 1460 OpAsmParser::OperandType expr; 1461 if (x.name == y.name && x.number == y.number) { 1462 expr = z; 1463 result.addAttribute("isXBinopExpr", parser.getBuilder().getUnitAttr()); 1464 } else if (x.name == z.name && x.number == z.number) { 1465 expr = y; 1466 } else { 1467 return parser.emitError(parser.getNameLoc()) 1468 << "atomic update variable " << x.name 1469 << " not found in the RHS of the assignment statement in an" 1470 " atomic.update operation"; 1471 } 1472 return parser.resolveOperand(expr, exprType, result.operands); 1473 } 1474 1475 /// Printer for AtomicUpdateOp 1476 static void printAtomicUpdateOp(OpAsmPrinter &p, AtomicUpdateOp op) { 1477 p << " " << op.x() << " = "; 1478 Value y, z; 1479 if (op.isXBinopExpr()) { 1480 y = op.x(); 1481 z = op.expr(); 1482 } else { 1483 y = op.expr(); 1484 z = op.x(); 1485 } 1486 p << y << " " << AtomicBinOpKindToString(op.binop()).lower() << " " << z 1487 << " "; 1488 if (op.memory_order()) 1489 p << "memory_order(" << op.memory_order() << ") "; 1490 if (op.hintAttr()) 1491 printSynchronizationHint(p, op, op.hintAttr()); 1492 p << ": " << op.x().getType() << ", " << op.expr().getType(); 1493 } 1494 1495 /// Verifier for AtomicUpdateOp 1496 static LogicalResult verifyAtomicUpdateOp(AtomicUpdateOp op) { 1497 if (op.memory_order()) { 1498 StringRef memoryOrder = op.memory_order().getValue(); 1499 if (memoryOrder.equals("acq_rel") || memoryOrder.equals("acquire")) 1500 return op.emitError( 1501 "memory-order must not be acq_rel or acquire for atomic updates"); 1502 } 1503 return success(); 1504 } 1505 1506 #define GET_OP_CLASSES 1507 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 1508