1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the OpenMP dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/DialectImplementation.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/IR/OperationSupport.h" 19 20 #include "llvm/ADT/BitVector.h" 21 #include "llvm/ADT/SmallString.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/ADT/StringRef.h" 24 #include "llvm/ADT/StringSwitch.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 #include <cstddef> 27 28 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc" 29 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" 30 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc" 31 32 using namespace mlir; 33 using namespace mlir::omp; 34 35 namespace { 36 /// Model for pointer-like types that already provide a `getElementType` method. 37 template <typename T> 38 struct PointerLikeModel 39 : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> { 40 Type getElementType(Type pointer) const { 41 return pointer.cast<T>().getElementType(); 42 } 43 }; 44 } // namespace 45 46 void OpenMPDialect::initialize() { 47 addOperations< 48 #define GET_OP_LIST 49 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 50 >(); 51 addAttributes< 52 #define GET_ATTRDEF_LIST 53 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" 54 >(); 55 56 LLVM::LLVMPointerType::attachInterface< 57 PointerLikeModel<LLVM::LLVMPointerType>>(*getContext()); 58 MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext()); 59 } 60 61 //===----------------------------------------------------------------------===// 62 // ParallelOp 63 //===----------------------------------------------------------------------===// 64 65 void ParallelOp::build(OpBuilder &builder, OperationState &state, 66 ArrayRef<NamedAttribute> attributes) { 67 ParallelOp::build( 68 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, 69 /*default_val=*/nullptr, /*private_vars=*/ValueRange(), 70 /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(), 71 /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(), 72 /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr); 73 state.addAttributes(attributes); 74 } 75 76 //===----------------------------------------------------------------------===// 77 // Parser and printer for Operand and type list 78 //===----------------------------------------------------------------------===// 79 80 /// Parse a list of operands with types. 81 /// 82 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)` 83 /// ssa-id-and-type-list ::= ssa-id-and-type | 84 /// ssa-id-and-type `,` ssa-id-and-type-list 85 /// ssa-id-and-type ::= ssa-id `:` type 86 static ParseResult 87 parseOperandAndTypeList(OpAsmParser &parser, 88 SmallVectorImpl<OpAsmParser::OperandType> &operands, 89 SmallVectorImpl<Type> &types) { 90 return parser.parseCommaSeparatedList( 91 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 92 OpAsmParser::OperandType operand; 93 Type type; 94 if (parser.parseOperand(operand) || parser.parseColonType(type)) 95 return failure(); 96 operands.push_back(operand); 97 types.push_back(type); 98 return success(); 99 }); 100 } 101 102 /// Print an operand and type list with parentheses 103 static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) { 104 p << "("; 105 llvm::interleaveComma( 106 operands, p, [&](const Value &v) { p << v << " : " << v.getType(); }); 107 p << ") "; 108 } 109 110 /// Print data variables corresponding to a data-sharing clause `name` 111 static void printDataVars(OpAsmPrinter &p, OperandRange operands, 112 StringRef name) { 113 if (!operands.empty()) { 114 p << name; 115 printOperandAndTypeList(p, operands); 116 } 117 } 118 119 //===----------------------------------------------------------------------===// 120 // Parser and printer for Allocate Clause 121 //===----------------------------------------------------------------------===// 122 123 /// Parse an allocate clause with allocators and a list of operands with types. 124 /// 125 /// allocate ::= `allocate` `(` allocate-operand-list `)` 126 /// allocate-operand-list :: = allocate-operand | 127 /// allocator-operand `,` allocate-operand-list 128 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 129 /// ssa-id-and-type ::= ssa-id `:` type 130 static ParseResult parseAllocateAndAllocator( 131 OpAsmParser &parser, 132 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate, 133 SmallVectorImpl<Type> &typesAllocate, 134 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator, 135 SmallVectorImpl<Type> &typesAllocator) { 136 137 return parser.parseCommaSeparatedList( 138 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 139 OpAsmParser::OperandType operand; 140 Type type; 141 if (parser.parseOperand(operand) || parser.parseColonType(type)) 142 return failure(); 143 operandsAllocator.push_back(operand); 144 typesAllocator.push_back(type); 145 if (parser.parseArrow()) 146 return failure(); 147 if (parser.parseOperand(operand) || parser.parseColonType(type)) 148 return failure(); 149 150 operandsAllocate.push_back(operand); 151 typesAllocate.push_back(type); 152 return success(); 153 }); 154 } 155 156 /// Print allocate clause 157 static void printAllocateAndAllocator(OpAsmPrinter &p, 158 OperandRange varsAllocate, 159 OperandRange varsAllocator) { 160 p << "allocate("; 161 for (unsigned i = 0; i < varsAllocate.size(); ++i) { 162 std::string separator = i == varsAllocate.size() - 1 ? ") " : ", "; 163 p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> "; 164 p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator; 165 } 166 } 167 168 LogicalResult ParallelOp::verify() { 169 if (allocate_vars().size() != allocators_vars().size()) 170 return emitError( 171 "expected equal sizes for allocate and allocator variables"); 172 return success(); 173 } 174 175 void ParallelOp::print(OpAsmPrinter &p) { 176 p << " "; 177 if (auto ifCond = if_expr_var()) 178 p << "if(" << ifCond << " : " << ifCond.getType() << ") "; 179 180 if (auto threads = num_threads_var()) 181 p << "num_threads(" << threads << " : " << threads.getType() << ") "; 182 183 printDataVars(p, private_vars(), "private"); 184 printDataVars(p, firstprivate_vars(), "firstprivate"); 185 printDataVars(p, shared_vars(), "shared"); 186 printDataVars(p, copyin_vars(), "copyin"); 187 188 if (!allocate_vars().empty()) 189 printAllocateAndAllocator(p, allocate_vars(), allocators_vars()); 190 191 if (auto def = default_val()) 192 p << "default(" << stringifyClauseDefault(*def).drop_front(3) << ") "; 193 194 if (auto bind = proc_bind_val()) 195 p << "proc_bind(" << stringifyClauseProcBindKind(*bind) << ") "; 196 197 p << ' '; 198 p.printRegion(getRegion()); 199 } 200 201 void TargetOp::print(OpAsmPrinter &p) { 202 p << " "; 203 if (auto ifCond = if_expr()) 204 p << "if(" << ifCond << " : " << ifCond.getType() << ") "; 205 206 if (auto device = this->device()) 207 p << "device(" << device << " : " << device.getType() << ") "; 208 209 if (auto threads = thread_limit()) 210 p << "thread_limit(" << threads << " : " << threads.getType() << ") "; 211 212 if (nowait()) 213 p << "nowait "; 214 215 p.printRegion(getRegion()); 216 } 217 218 //===----------------------------------------------------------------------===// 219 // Parser and printer for Linear Clause 220 //===----------------------------------------------------------------------===// 221 222 /// linear ::= `linear` `(` linear-list `)` 223 /// linear-list := linear-val | linear-val linear-list 224 /// linear-val := ssa-id-and-type `=` ssa-id-and-type 225 static ParseResult 226 parseLinearClause(OpAsmParser &parser, 227 SmallVectorImpl<OpAsmParser::OperandType> &vars, 228 SmallVectorImpl<Type> &types, 229 SmallVectorImpl<OpAsmParser::OperandType> &stepVars) { 230 if (parser.parseLParen()) 231 return failure(); 232 233 do { 234 OpAsmParser::OperandType var; 235 Type type; 236 OpAsmParser::OperandType stepVar; 237 if (parser.parseOperand(var) || parser.parseEqual() || 238 parser.parseOperand(stepVar) || parser.parseColonType(type)) 239 return failure(); 240 241 vars.push_back(var); 242 types.push_back(type); 243 stepVars.push_back(stepVar); 244 } while (succeeded(parser.parseOptionalComma())); 245 246 if (parser.parseRParen()) 247 return failure(); 248 249 return success(); 250 } 251 252 /// Print Linear Clause 253 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars, 254 OperandRange linearStepVars) { 255 size_t linearVarsSize = linearVars.size(); 256 p << "linear("; 257 for (unsigned i = 0; i < linearVarsSize; ++i) { 258 std::string separator = i == linearVarsSize - 1 ? ") " : ", "; 259 p << linearVars[i]; 260 if (linearStepVars.size() > i) 261 p << " = " << linearStepVars[i]; 262 p << " : " << linearVars[i].getType() << separator; 263 } 264 } 265 266 //===----------------------------------------------------------------------===// 267 // Parser and printer for Schedule Clause 268 //===----------------------------------------------------------------------===// 269 270 static ParseResult 271 verifyScheduleModifiers(OpAsmParser &parser, 272 SmallVectorImpl<SmallString<12>> &modifiers) { 273 if (modifiers.size() > 2) 274 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)"; 275 for (const auto &mod : modifiers) { 276 // Translate the string. If it has no value, then it was not a valid 277 // modifier! 278 auto symbol = symbolizeScheduleModifier(mod); 279 if (!symbol.hasValue()) 280 return parser.emitError(parser.getNameLoc()) 281 << " unknown modifier type: " << mod; 282 } 283 284 // If we have one modifier that is "simd", then stick a "none" modiifer in 285 // index 0. 286 if (modifiers.size() == 1) { 287 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) { 288 modifiers.push_back(modifiers[0]); 289 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none); 290 } 291 } else if (modifiers.size() == 2) { 292 // If there are two modifier: 293 // First modifier should not be simd, second one should be simd 294 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd || 295 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd) 296 return parser.emitError(parser.getNameLoc()) 297 << " incorrect modifier order"; 298 } 299 return success(); 300 } 301 302 /// schedule ::= `schedule` `(` sched-list `)` 303 /// sched-list ::= sched-val | sched-val sched-list | 304 /// sched-val `,` sched-modifier 305 /// sched-val ::= sched-with-chunk | sched-wo-chunk 306 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 307 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided` 308 /// sched-wo-chunk ::= `auto` | `runtime` 309 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val 310 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none` 311 static ParseResult 312 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, 313 SmallVectorImpl<SmallString<12>> &modifiers, 314 Optional<OpAsmParser::OperandType> &chunkSize, 315 Type &chunkType) { 316 if (parser.parseLParen()) 317 return failure(); 318 319 StringRef keyword; 320 if (parser.parseKeyword(&keyword)) 321 return failure(); 322 323 schedule = keyword; 324 if (keyword == "static" || keyword == "dynamic" || keyword == "guided") { 325 if (succeeded(parser.parseOptionalEqual())) { 326 chunkSize = OpAsmParser::OperandType{}; 327 if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType)) 328 return failure(); 329 } else { 330 chunkSize = llvm::NoneType::None; 331 } 332 } else if (keyword == "auto" || keyword == "runtime") { 333 chunkSize = llvm::NoneType::None; 334 } else { 335 return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; 336 } 337 338 // If there is a comma, we have one or more modifiers.. 339 while (succeeded(parser.parseOptionalComma())) { 340 StringRef mod; 341 if (parser.parseKeyword(&mod)) 342 return failure(); 343 modifiers.push_back(mod); 344 } 345 346 if (parser.parseRParen()) 347 return failure(); 348 349 if (verifyScheduleModifiers(parser, modifiers)) 350 return failure(); 351 352 return success(); 353 } 354 355 /// Print schedule clause 356 static void printScheduleClause(OpAsmPrinter &p, ClauseScheduleKind sched, 357 Optional<ScheduleModifier> modifier, bool simd, 358 Value scheduleChunkVar) { 359 p << "schedule(" << stringifyClauseScheduleKind(sched).lower(); 360 if (scheduleChunkVar) 361 p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType(); 362 if (modifier) 363 p << ", " << stringifyScheduleModifier(*modifier); 364 if (simd) 365 p << ", simd"; 366 p << ") "; 367 } 368 369 //===----------------------------------------------------------------------===// 370 // Parser, printer and verifier for ReductionVarList 371 //===----------------------------------------------------------------------===// 372 373 /// reduction ::= `reduction` `(` reduction-entry-list `)` 374 /// reduction-entry-list ::= reduction-entry 375 /// | reduction-entry-list `,` reduction-entry 376 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type 377 static ParseResult 378 parseReductionVarList(OpAsmParser &parser, 379 SmallVectorImpl<SymbolRefAttr> &symbols, 380 SmallVectorImpl<OpAsmParser::OperandType> &operands, 381 SmallVectorImpl<Type> &types) { 382 if (failed(parser.parseLParen())) 383 return failure(); 384 385 do { 386 if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() || 387 parser.parseOperand(operands.emplace_back()) || 388 parser.parseColonType(types.emplace_back())) 389 return failure(); 390 } while (succeeded(parser.parseOptionalComma())); 391 return parser.parseRParen(); 392 } 393 394 /// Print Reduction clause 395 static void printReductionVarList(OpAsmPrinter &p, 396 Optional<ArrayAttr> reductions, 397 OperandRange reductionVars) { 398 p << "reduction("; 399 for (unsigned i = 0, e = reductions->size(); i < e; ++i) { 400 if (i != 0) 401 p << ", "; 402 p << (*reductions)[i] << " -> " << reductionVars[i] << " : " 403 << reductionVars[i].getType(); 404 } 405 p << ") "; 406 } 407 408 /// Verifies Reduction Clause 409 static LogicalResult verifyReductionVarList(Operation *op, 410 Optional<ArrayAttr> reductions, 411 OperandRange reductionVars) { 412 if (!reductionVars.empty()) { 413 if (!reductions || reductions->size() != reductionVars.size()) 414 return op->emitOpError() 415 << "expected as many reduction symbol references " 416 "as reduction variables"; 417 } else { 418 if (reductions) 419 return op->emitOpError() << "unexpected reduction symbol references"; 420 return success(); 421 } 422 423 DenseSet<Value> accumulators; 424 for (auto args : llvm::zip(reductionVars, *reductions)) { 425 Value accum = std::get<0>(args); 426 427 if (!accumulators.insert(accum).second) 428 return op->emitOpError() << "accumulator variable used more than once"; 429 430 Type varType = accum.getType().cast<PointerLikeType>(); 431 auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>(); 432 auto decl = 433 SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef); 434 if (!decl) 435 return op->emitOpError() << "expected symbol reference " << symbolRef 436 << " to point to a reduction declaration"; 437 438 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) 439 return op->emitOpError() 440 << "expected accumulator (" << varType 441 << ") to be the same type as reduction declaration (" 442 << decl.getAccumulatorType() << ")"; 443 } 444 445 return success(); 446 } 447 448 //===----------------------------------------------------------------------===// 449 // Parser, printer and verifier for Synchronization Hint (2.17.12) 450 //===----------------------------------------------------------------------===// 451 452 /// Parses a Synchronization Hint clause. The value of hint is an integer 453 /// which is a combination of different hints from `omp_sync_hint_t`. 454 /// 455 /// hint-clause = `hint` `(` hint-value `)` 456 static ParseResult parseSynchronizationHint(OpAsmParser &parser, 457 IntegerAttr &hintAttr, 458 bool parseKeyword = true) { 459 if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) { 460 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0); 461 return success(); 462 } 463 464 if (failed(parser.parseLParen())) 465 return failure(); 466 StringRef hintKeyword; 467 int64_t hint = 0; 468 do { 469 if (failed(parser.parseKeyword(&hintKeyword))) 470 return failure(); 471 if (hintKeyword == "uncontended") 472 hint |= 1; 473 else if (hintKeyword == "contended") 474 hint |= 2; 475 else if (hintKeyword == "nonspeculative") 476 hint |= 4; 477 else if (hintKeyword == "speculative") 478 hint |= 8; 479 else 480 return parser.emitError(parser.getCurrentLocation()) 481 << hintKeyword << " is not a valid hint"; 482 } while (succeeded(parser.parseOptionalComma())); 483 if (failed(parser.parseRParen())) 484 return failure(); 485 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint); 486 return success(); 487 } 488 489 /// Prints a Synchronization Hint clause 490 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, 491 IntegerAttr hintAttr) { 492 int64_t hint = hintAttr.getInt(); 493 494 if (hint == 0) 495 return; 496 497 // Helper function to get n-th bit from the right end of `value` 498 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 499 500 bool uncontended = bitn(hint, 0); 501 bool contended = bitn(hint, 1); 502 bool nonspeculative = bitn(hint, 2); 503 bool speculative = bitn(hint, 3); 504 505 SmallVector<StringRef> hints; 506 if (uncontended) 507 hints.push_back("uncontended"); 508 if (contended) 509 hints.push_back("contended"); 510 if (nonspeculative) 511 hints.push_back("nonspeculative"); 512 if (speculative) 513 hints.push_back("speculative"); 514 515 p << "hint("; 516 llvm::interleaveComma(hints, p); 517 p << ") "; 518 } 519 520 /// Verifies a synchronization hint clause 521 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) { 522 523 // Helper function to get n-th bit from the right end of `value` 524 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 525 526 bool uncontended = bitn(hint, 0); 527 bool contended = bitn(hint, 1); 528 bool nonspeculative = bitn(hint, 2); 529 bool speculative = bitn(hint, 3); 530 531 if (uncontended && contended) 532 return op->emitOpError() << "the hints omp_sync_hint_uncontended and " 533 "omp_sync_hint_contended cannot be combined"; 534 if (nonspeculative && speculative) 535 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and " 536 "omp_sync_hint_speculative cannot be combined."; 537 return success(); 538 } 539 540 enum ClauseType { 541 ifClause, 542 numThreadsClause, 543 deviceClause, 544 threadLimitClause, 545 privateClause, 546 firstprivateClause, 547 lastprivateClause, 548 sharedClause, 549 copyinClause, 550 allocateClause, 551 defaultClause, 552 procBindClause, 553 reductionClause, 554 nowaitClause, 555 linearClause, 556 scheduleClause, 557 collapseClause, 558 orderClause, 559 orderedClause, 560 memoryOrderClause, 561 hintClause, 562 COUNT 563 }; 564 565 //===----------------------------------------------------------------------===// 566 // Parser for Clause List 567 //===----------------------------------------------------------------------===// 568 569 /// Parse a clause attribute `(` $value `)`. 570 template <typename ClauseAttr> 571 static ParseResult parseClauseAttr(AsmParser &parser, OperationState &state, 572 StringRef attrName, StringRef name) { 573 using ClauseT = decltype(std::declval<ClauseAttr>().getValue()); 574 StringRef enumStr; 575 SMLoc loc = parser.getCurrentLocation(); 576 if (parser.parseLParen() || parser.parseKeyword(&enumStr) || 577 parser.parseRParen()) 578 return failure(); 579 if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) { 580 auto attr = ClauseAttr::get(parser.getContext(), *enumValue); 581 state.addAttribute(attrName, attr); 582 return success(); 583 } 584 return parser.emitError(loc, "invalid ") << name << " kind"; 585 } 586 587 /// Parse a list of clauses. The clauses can appear in any order, but their 588 /// operand segment indices are in the same order that they are passed in the 589 /// `clauses` list. The operand segments are added over the prevSegments 590 591 /// clause-list ::= clause clause-list | empty 592 /// clause ::= if | num-threads | private | firstprivate | lastprivate | 593 /// shared | copyin | allocate | default | proc-bind | reduction | 594 /// nowait | linear | schedule | collapse | order | ordered | 595 /// inclusive 596 /// if ::= `if` `(` ssa-id-and-type `)` 597 /// num-threads ::= `num_threads` `(` ssa-id-and-type `)` 598 /// private ::= `private` operand-and-type-list 599 /// firstprivate ::= `firstprivate` operand-and-type-list 600 /// lastprivate ::= `lastprivate` operand-and-type-list 601 /// shared ::= `shared` operand-and-type-list 602 /// copyin ::= `copyin` operand-and-type-list 603 /// allocate ::= `allocate` `(` allocate-operand-list `)` 604 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) 605 /// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` 606 /// reduction ::= `reduction` `(` reduction-entry-list `)` 607 /// nowait ::= `nowait` 608 /// linear ::= `linear` `(` linear-list `)` 609 /// schedule ::= `schedule` `(` sched-list `)` 610 /// collapse ::= `collapse` `(` ssa-id-and-type `)` 611 /// order ::= `order` `(` `concurrent` `)` 612 /// ordered ::= `ordered` `(` ssa-id-and-type `)` 613 /// inclusive ::= `inclusive` 614 /// 615 /// Note that each clause can only appear once in the clase-list. 616 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result, 617 SmallVectorImpl<ClauseType> &clauses, 618 SmallVectorImpl<int> &segments) { 619 620 // Check done[clause] to see if it has been parsed already 621 BitVector done(ClauseType::COUNT, false); 622 623 // See pos[clause] to get position of clause in operand segments 624 SmallVector<int> pos(ClauseType::COUNT, -1); 625 626 // Stores the last parsed clause keyword 627 StringRef clauseKeyword; 628 StringRef opName = result.name.getStringRef(); 629 630 // Containers for storing operands, types and attributes for various clauses 631 std::pair<OpAsmParser::OperandType, Type> ifCond; 632 std::pair<OpAsmParser::OperandType, Type> numThreads; 633 std::pair<OpAsmParser::OperandType, Type> device; 634 std::pair<OpAsmParser::OperandType, Type> threadLimit; 635 636 SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates, 637 shareds, copyins; 638 SmallVector<Type> privateTypes, firstprivateTypes, lastprivateTypes, 639 sharedTypes, copyinTypes; 640 641 SmallVector<OpAsmParser::OperandType> allocates, allocators; 642 SmallVector<Type> allocateTypes, allocatorTypes; 643 644 SmallVector<SymbolRefAttr> reductionSymbols; 645 SmallVector<OpAsmParser::OperandType> reductionVars; 646 SmallVector<Type> reductionVarTypes; 647 648 SmallVector<OpAsmParser::OperandType> linears; 649 SmallVector<Type> linearTypes; 650 SmallVector<OpAsmParser::OperandType> linearSteps; 651 652 SmallString<8> schedule; 653 SmallVector<SmallString<12>> modifiers; 654 Optional<OpAsmParser::OperandType> scheduleChunkSize; 655 Type scheduleChunkType; 656 657 // Compute the position of clauses in operand segments 658 int currPos = 0; 659 for (ClauseType clause : clauses) { 660 661 // Skip the following clauses - they do not take any position in operand 662 // segments 663 if (clause == defaultClause || clause == procBindClause || 664 clause == nowaitClause || clause == collapseClause || 665 clause == orderClause || clause == orderedClause) 666 continue; 667 668 pos[clause] = currPos++; 669 670 // For the following clauses, two positions are reserved in the operand 671 // segments 672 if (clause == allocateClause || clause == linearClause) 673 currPos++; 674 } 675 676 SmallVector<int> clauseSegments(currPos); 677 678 // Helper function to check if a clause is allowed/repeated or not 679 auto checkAllowed = [&](ClauseType clause) -> ParseResult { 680 if (!llvm::is_contained(clauses, clause)) 681 return parser.emitError(parser.getCurrentLocation()) 682 << clauseKeyword << " is not a valid clause for the " << opName 683 << " operation"; 684 if (done[clause]) 685 return parser.emitError(parser.getCurrentLocation()) 686 << "at most one " << clauseKeyword << " clause can appear on the " 687 << opName << " operation"; 688 done[clause] = true; 689 return success(); 690 }; 691 692 while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) { 693 if (clauseKeyword == "if") { 694 if (checkAllowed(ifClause) || parser.parseLParen() || 695 parser.parseOperand(ifCond.first) || 696 parser.parseColonType(ifCond.second) || parser.parseRParen()) 697 return failure(); 698 clauseSegments[pos[ifClause]] = 1; 699 } else if (clauseKeyword == "num_threads") { 700 if (checkAllowed(numThreadsClause) || parser.parseLParen() || 701 parser.parseOperand(numThreads.first) || 702 parser.parseColonType(numThreads.second) || parser.parseRParen()) 703 return failure(); 704 clauseSegments[pos[numThreadsClause]] = 1; 705 } else if (clauseKeyword == "device") { 706 if (checkAllowed(deviceClause) || parser.parseLParen() || 707 parser.parseOperand(device.first) || 708 parser.parseColonType(device.second) || parser.parseRParen()) 709 return failure(); 710 clauseSegments[pos[deviceClause]] = 1; 711 } else if (clauseKeyword == "thread_limit") { 712 if (checkAllowed(threadLimitClause) || parser.parseLParen() || 713 parser.parseOperand(threadLimit.first) || 714 parser.parseColonType(threadLimit.second) || parser.parseRParen()) 715 return failure(); 716 clauseSegments[pos[threadLimitClause]] = 1; 717 } else if (clauseKeyword == "private") { 718 if (checkAllowed(privateClause) || 719 parseOperandAndTypeList(parser, privates, privateTypes)) 720 return failure(); 721 clauseSegments[pos[privateClause]] = privates.size(); 722 } else if (clauseKeyword == "firstprivate") { 723 if (checkAllowed(firstprivateClause) || 724 parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) 725 return failure(); 726 clauseSegments[pos[firstprivateClause]] = firstprivates.size(); 727 } else if (clauseKeyword == "lastprivate") { 728 if (checkAllowed(lastprivateClause) || 729 parseOperandAndTypeList(parser, lastprivates, lastprivateTypes)) 730 return failure(); 731 clauseSegments[pos[lastprivateClause]] = lastprivates.size(); 732 } else if (clauseKeyword == "shared") { 733 if (checkAllowed(sharedClause) || 734 parseOperandAndTypeList(parser, shareds, sharedTypes)) 735 return failure(); 736 clauseSegments[pos[sharedClause]] = shareds.size(); 737 } else if (clauseKeyword == "copyin") { 738 if (checkAllowed(copyinClause) || 739 parseOperandAndTypeList(parser, copyins, copyinTypes)) 740 return failure(); 741 clauseSegments[pos[copyinClause]] = copyins.size(); 742 } else if (clauseKeyword == "allocate") { 743 if (checkAllowed(allocateClause) || 744 parseAllocateAndAllocator(parser, allocates, allocateTypes, 745 allocators, allocatorTypes)) 746 return failure(); 747 clauseSegments[pos[allocateClause]] = allocates.size(); 748 clauseSegments[pos[allocateClause] + 1] = allocators.size(); 749 } else if (clauseKeyword == "default") { 750 StringRef defval; 751 SMLoc loc = parser.getCurrentLocation(); 752 if (checkAllowed(defaultClause) || parser.parseLParen() || 753 parser.parseKeyword(&defval) || parser.parseRParen()) 754 return failure(); 755 // The def prefix is required for the attribute as "private" is a keyword 756 // in C++. 757 if (Optional<ClauseDefault> def = 758 symbolizeClauseDefault(("def" + defval).str())) { 759 result.addAttribute("default_val", 760 ClauseDefaultAttr::get(parser.getContext(), *def)); 761 } else { 762 return parser.emitError(loc, "invalid default clause"); 763 } 764 } else if (clauseKeyword == "proc_bind") { 765 if (checkAllowed(procBindClause) || 766 parseClauseAttr<ClauseProcBindKindAttr>(parser, result, 767 "proc_bind_val", "proc bind")) 768 return failure(); 769 } else if (clauseKeyword == "reduction") { 770 if (checkAllowed(reductionClause) || 771 parseReductionVarList(parser, reductionSymbols, reductionVars, 772 reductionVarTypes)) 773 return failure(); 774 clauseSegments[pos[reductionClause]] = reductionVars.size(); 775 } else if (clauseKeyword == "nowait") { 776 if (checkAllowed(nowaitClause)) 777 return failure(); 778 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 779 result.addAttribute("nowait", attr); 780 } else if (clauseKeyword == "linear") { 781 if (checkAllowed(linearClause) || 782 parseLinearClause(parser, linears, linearTypes, linearSteps)) 783 return failure(); 784 clauseSegments[pos[linearClause]] = linears.size(); 785 clauseSegments[pos[linearClause] + 1] = linearSteps.size(); 786 } else if (clauseKeyword == "schedule") { 787 if (checkAllowed(scheduleClause) || 788 parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize, 789 scheduleChunkType)) 790 return failure(); 791 if (scheduleChunkSize) { 792 clauseSegments[pos[scheduleClause]] = 1; 793 } 794 } else if (clauseKeyword == "collapse") { 795 auto type = parser.getBuilder().getI64Type(); 796 mlir::IntegerAttr attr; 797 if (checkAllowed(collapseClause) || parser.parseLParen() || 798 parser.parseAttribute(attr, type) || parser.parseRParen()) 799 return failure(); 800 result.addAttribute("collapse_val", attr); 801 } else if (clauseKeyword == "ordered") { 802 mlir::IntegerAttr attr; 803 if (checkAllowed(orderedClause)) 804 return failure(); 805 if (succeeded(parser.parseOptionalLParen())) { 806 auto type = parser.getBuilder().getI64Type(); 807 if (parser.parseAttribute(attr, type) || parser.parseRParen()) 808 return failure(); 809 } else { 810 // Use 0 to represent no ordered parameter was specified 811 attr = parser.getBuilder().getI64IntegerAttr(0); 812 } 813 result.addAttribute("ordered_val", attr); 814 } else if (clauseKeyword == "order") { 815 if (checkAllowed(orderClause) || 816 parseClauseAttr<ClauseOrderKindAttr>(parser, result, "order_val", 817 "order")) 818 return failure(); 819 } else if (clauseKeyword == "memory_order") { 820 if (checkAllowed(memoryOrderClause) || 821 parseClauseAttr<ClauseMemoryOrderKindAttr>( 822 parser, result, "memory_order", "memory order")) 823 return failure(); 824 } else if (clauseKeyword == "hint") { 825 IntegerAttr hint; 826 if (checkAllowed(hintClause) || 827 parseSynchronizationHint(parser, hint, false)) 828 return failure(); 829 result.addAttribute("hint", hint); 830 } else { 831 return parser.emitError(parser.getNameLoc()) 832 << clauseKeyword << " is not a valid clause"; 833 } 834 } 835 836 // Add if parameter. 837 if (done[ifClause] && clauseSegments[pos[ifClause]] && 838 failed( 839 parser.resolveOperand(ifCond.first, ifCond.second, result.operands))) 840 return failure(); 841 842 // Add num_threads parameter. 843 if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] && 844 failed(parser.resolveOperand(numThreads.first, numThreads.second, 845 result.operands))) 846 return failure(); 847 848 // Add device parameter. 849 if (done[deviceClause] && clauseSegments[pos[deviceClause]] && 850 failed( 851 parser.resolveOperand(device.first, device.second, result.operands))) 852 return failure(); 853 854 // Add thread_limit parameter. 855 if (done[threadLimitClause] && clauseSegments[pos[threadLimitClause]] && 856 failed(parser.resolveOperand(threadLimit.first, threadLimit.second, 857 result.operands))) 858 return failure(); 859 860 // Add private parameters. 861 if (done[privateClause] && clauseSegments[pos[privateClause]] && 862 failed(parser.resolveOperands(privates, privateTypes, 863 privates[0].location, result.operands))) 864 return failure(); 865 866 // Add firstprivate parameters. 867 if (done[firstprivateClause] && clauseSegments[pos[firstprivateClause]] && 868 failed(parser.resolveOperands(firstprivates, firstprivateTypes, 869 firstprivates[0].location, 870 result.operands))) 871 return failure(); 872 873 // Add lastprivate parameters. 874 if (done[lastprivateClause] && clauseSegments[pos[lastprivateClause]] && 875 failed(parser.resolveOperands(lastprivates, lastprivateTypes, 876 lastprivates[0].location, result.operands))) 877 return failure(); 878 879 // Add shared parameters. 880 if (done[sharedClause] && clauseSegments[pos[sharedClause]] && 881 failed(parser.resolveOperands(shareds, sharedTypes, shareds[0].location, 882 result.operands))) 883 return failure(); 884 885 // Add copyin parameters. 886 if (done[copyinClause] && clauseSegments[pos[copyinClause]] && 887 failed(parser.resolveOperands(copyins, copyinTypes, copyins[0].location, 888 result.operands))) 889 return failure(); 890 891 // Add allocate parameters. 892 if (done[allocateClause] && clauseSegments[pos[allocateClause]] && 893 failed(parser.resolveOperands(allocates, allocateTypes, 894 allocates[0].location, result.operands))) 895 return failure(); 896 897 // Add allocator parameters. 898 if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] && 899 failed(parser.resolveOperands(allocators, allocatorTypes, 900 allocators[0].location, result.operands))) 901 return failure(); 902 903 // Add reduction parameters and symbols 904 if (done[reductionClause] && clauseSegments[pos[reductionClause]]) { 905 if (failed(parser.resolveOperands(reductionVars, reductionVarTypes, 906 parser.getNameLoc(), result.operands))) 907 return failure(); 908 909 SmallVector<Attribute> reductions(reductionSymbols.begin(), 910 reductionSymbols.end()); 911 result.addAttribute("reductions", 912 parser.getBuilder().getArrayAttr(reductions)); 913 } 914 915 // Add linear parameters 916 if (done[linearClause] && clauseSegments[pos[linearClause]]) { 917 auto linearStepType = parser.getBuilder().getI32Type(); 918 SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType); 919 if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location, 920 result.operands)) || 921 failed(parser.resolveOperands(linearSteps, linearStepTypes, 922 linearSteps[0].location, 923 result.operands))) 924 return failure(); 925 } 926 927 // Add schedule parameters 928 if (done[scheduleClause] && !schedule.empty()) { 929 schedule[0] = llvm::toUpper(schedule[0]); 930 if (Optional<ClauseScheduleKind> sched = 931 symbolizeClauseScheduleKind(schedule)) { 932 auto attr = ClauseScheduleKindAttr::get(parser.getContext(), *sched); 933 result.addAttribute("schedule_val", attr); 934 } else { 935 return parser.emitError(parser.getCurrentLocation(), 936 "invalid schedule kind"); 937 } 938 if (!modifiers.empty()) { 939 SMLoc loc = parser.getCurrentLocation(); 940 if (Optional<ScheduleModifier> mod = 941 symbolizeScheduleModifier(modifiers[0])) { 942 result.addAttribute( 943 "schedule_modifier", 944 ScheduleModifierAttr::get(parser.getContext(), *mod)); 945 } else { 946 return parser.emitError(loc, "invalid schedule modifier"); 947 } 948 // Only SIMD attribute is allowed here! 949 if (modifiers.size() > 1) { 950 assert(symbolizeScheduleModifier(modifiers[1]) == 951 ScheduleModifier::simd); 952 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 953 result.addAttribute("simd_modifier", attr); 954 } 955 } 956 if (scheduleChunkSize) 957 parser.resolveOperand(*scheduleChunkSize, scheduleChunkType, 958 result.operands); 959 } 960 961 segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end()); 962 963 return success(); 964 } 965 966 /// Parses a parallel operation. 967 /// 968 /// operation ::= `omp.parallel` clause-list 969 /// clause-list ::= clause | clause clause-list 970 /// clause ::= if | num-threads | private | firstprivate | shared | copyin | 971 /// allocate | default | proc-bind 972 /// 973 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) { 974 SmallVector<ClauseType> clauses = { 975 ifClause, numThreadsClause, privateClause, 976 firstprivateClause, sharedClause, copyinClause, 977 allocateClause, defaultClause, procBindClause}; 978 979 SmallVector<int> segments; 980 981 if (failed(parseClauses(parser, result, clauses, segments))) 982 return failure(); 983 984 result.addAttribute("operand_segment_sizes", 985 parser.getBuilder().getI32VectorAttr(segments)); 986 987 Region *body = result.addRegion(); 988 SmallVector<OpAsmParser::OperandType> regionArgs; 989 SmallVector<Type> regionArgTypes; 990 if (parser.parseRegion(*body, regionArgs, regionArgTypes)) 991 return failure(); 992 return success(); 993 } 994 995 /// Parses a target operation. 996 /// 997 /// operation ::= `omp.target` clause-list 998 /// clause-list ::= clause | clause clause-list 999 /// clause ::= if | device | thread_limit | nowait 1000 /// 1001 ParseResult TargetOp::parse(OpAsmParser &parser, OperationState &result) { 1002 SmallVector<ClauseType> clauses = {ifClause, deviceClause, threadLimitClause, 1003 nowaitClause}; 1004 1005 SmallVector<int> segments; 1006 1007 if (failed(parseClauses(parser, result, clauses, segments))) 1008 return failure(); 1009 1010 result.addAttribute( 1011 TargetOp::AttrSizedOperandSegments::getOperandSegmentSizeAttr(), 1012 parser.getBuilder().getI32VectorAttr(segments)); 1013 1014 Region *body = result.addRegion(); 1015 SmallVector<OpAsmParser::OperandType> regionArgs; 1016 SmallVector<Type> regionArgTypes; 1017 if (parser.parseRegion(*body, regionArgs, regionArgTypes)) 1018 return failure(); 1019 return success(); 1020 } 1021 1022 //===----------------------------------------------------------------------===// 1023 // Parser, printer and verifier for SectionsOp 1024 //===----------------------------------------------------------------------===// 1025 1026 /// Parses an OpenMP Sections operation 1027 /// 1028 /// sections ::= `omp.sections` clause-list 1029 /// clause-list ::= clause clause-list | empty 1030 /// clause ::= private | firstprivate | lastprivate | reduction | allocate | 1031 /// nowait 1032 ParseResult SectionsOp::parse(OpAsmParser &parser, OperationState &result) { 1033 SmallVector<ClauseType> clauses = {privateClause, firstprivateClause, 1034 lastprivateClause, reductionClause, 1035 allocateClause, nowaitClause}; 1036 1037 SmallVector<int> segments; 1038 1039 if (failed(parseClauses(parser, result, clauses, segments))) 1040 return failure(); 1041 1042 result.addAttribute("operand_segment_sizes", 1043 parser.getBuilder().getI32VectorAttr(segments)); 1044 1045 // Now parse the body. 1046 Region *body = result.addRegion(); 1047 if (parser.parseRegion(*body)) 1048 return failure(); 1049 return success(); 1050 } 1051 1052 void SectionsOp::print(OpAsmPrinter &p) { 1053 p << " "; 1054 printDataVars(p, private_vars(), "private"); 1055 printDataVars(p, firstprivate_vars(), "firstprivate"); 1056 printDataVars(p, lastprivate_vars(), "lastprivate"); 1057 1058 if (!reduction_vars().empty()) 1059 printReductionVarList(p, reductions(), reduction_vars()); 1060 1061 if (!allocate_vars().empty()) 1062 printAllocateAndAllocator(p, allocate_vars(), allocators_vars()); 1063 1064 if (nowait()) 1065 p << "nowait"; 1066 1067 p << ' '; 1068 p.printRegion(region()); 1069 } 1070 1071 LogicalResult SectionsOp::verify() { 1072 // A list item may not appear in more than one clause on the same directive, 1073 // except that it may be specified in both firstprivate and lastprivate 1074 // clauses. 1075 for (auto var : private_vars()) { 1076 if (llvm::is_contained(firstprivate_vars(), var)) 1077 return emitOpError() 1078 << "operand used in both private and firstprivate clauses"; 1079 if (llvm::is_contained(lastprivate_vars(), var)) 1080 return emitOpError() 1081 << "operand used in both private and lastprivate clauses"; 1082 } 1083 1084 if (allocate_vars().size() != allocators_vars().size()) 1085 return emitError( 1086 "expected equal sizes for allocate and allocator variables"); 1087 1088 for (auto &inst : *region().begin()) { 1089 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) { 1090 return emitOpError() 1091 << "expected omp.section op or terminator op inside region"; 1092 } 1093 } 1094 1095 return verifyReductionVarList(*this, reductions(), reduction_vars()); 1096 } 1097 1098 /// Parses an OpenMP Workshare Loop operation 1099 /// 1100 /// wsloop ::= `omp.wsloop` loop-control clause-list 1101 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds 1102 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps 1103 /// steps := `step` `(`ssa-id-list`)` 1104 /// clause-list ::= clause clause-list | empty 1105 /// clause ::= private | firstprivate | lastprivate | linear | schedule | 1106 // collapse | nowait | ordered | order | reduction 1107 ParseResult WsLoopOp::parse(OpAsmParser &parser, OperationState &result) { 1108 // Parse an opening `(` followed by induction variables followed by `)` 1109 SmallVector<OpAsmParser::OperandType> ivs; 1110 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, 1111 OpAsmParser::Delimiter::Paren)) 1112 return failure(); 1113 1114 int numIVs = static_cast<int>(ivs.size()); 1115 Type loopVarType; 1116 if (parser.parseColonType(loopVarType)) 1117 return failure(); 1118 1119 // Parse loop bounds. 1120 SmallVector<OpAsmParser::OperandType> lower; 1121 if (parser.parseEqual() || 1122 parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) || 1123 parser.resolveOperands(lower, loopVarType, result.operands)) 1124 return failure(); 1125 1126 SmallVector<OpAsmParser::OperandType> upper; 1127 if (parser.parseKeyword("to") || 1128 parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) || 1129 parser.resolveOperands(upper, loopVarType, result.operands)) 1130 return failure(); 1131 1132 if (succeeded(parser.parseOptionalKeyword("inclusive"))) { 1133 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 1134 result.addAttribute("inclusive", attr); 1135 } 1136 1137 // Parse step values. 1138 SmallVector<OpAsmParser::OperandType> steps; 1139 if (parser.parseKeyword("step") || 1140 parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) || 1141 parser.resolveOperands(steps, loopVarType, result.operands)) 1142 return failure(); 1143 1144 SmallVector<ClauseType> clauses = { 1145 privateClause, firstprivateClause, lastprivateClause, linearClause, 1146 reductionClause, collapseClause, orderClause, orderedClause, 1147 nowaitClause, scheduleClause}; 1148 SmallVector<int> segments{numIVs, numIVs, numIVs}; 1149 if (failed(parseClauses(parser, result, clauses, segments))) 1150 return failure(); 1151 1152 result.addAttribute("operand_segment_sizes", 1153 parser.getBuilder().getI32VectorAttr(segments)); 1154 1155 // Now parse the body. 1156 Region *body = result.addRegion(); 1157 SmallVector<Type> ivTypes(numIVs, loopVarType); 1158 SmallVector<OpAsmParser::OperandType> blockArgs(ivs); 1159 if (parser.parseRegion(*body, blockArgs, ivTypes)) 1160 return failure(); 1161 return success(); 1162 } 1163 1164 void WsLoopOp::print(OpAsmPrinter &p) { 1165 auto args = getRegion().front().getArguments(); 1166 p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound() 1167 << ") to (" << upperBound() << ") "; 1168 if (inclusive()) { 1169 p << "inclusive "; 1170 } 1171 p << "step (" << step() << ") "; 1172 1173 printDataVars(p, private_vars(), "private"); 1174 printDataVars(p, firstprivate_vars(), "firstprivate"); 1175 printDataVars(p, lastprivate_vars(), "lastprivate"); 1176 1177 if (!linear_vars().empty()) 1178 printLinearClause(p, linear_vars(), linear_step_vars()); 1179 1180 if (auto sched = schedule_val()) 1181 printScheduleClause(p, sched.getValue(), schedule_modifier(), 1182 simd_modifier(), schedule_chunk_var()); 1183 1184 if (auto collapse = collapse_val()) 1185 p << "collapse(" << collapse << ") "; 1186 1187 if (nowait()) 1188 p << "nowait "; 1189 1190 if (auto ordered = ordered_val()) 1191 p << "ordered(" << ordered << ") "; 1192 1193 if (auto order = order_val()) 1194 p << "order(" << stringifyClauseOrderKind(*order) << ") "; 1195 1196 if (!reduction_vars().empty()) 1197 printReductionVarList(p, reductions(), reduction_vars()); 1198 1199 p << ' '; 1200 p.printRegion(region(), /*printEntryBlockArgs=*/false); 1201 } 1202 1203 //===----------------------------------------------------------------------===// 1204 // ReductionOp 1205 //===----------------------------------------------------------------------===// 1206 1207 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, 1208 Region ®ion) { 1209 if (parser.parseOptionalKeyword("atomic")) 1210 return success(); 1211 return parser.parseRegion(region); 1212 } 1213 1214 static void printAtomicReductionRegion(OpAsmPrinter &printer, 1215 ReductionDeclareOp op, Region ®ion) { 1216 if (region.empty()) 1217 return; 1218 printer << "atomic "; 1219 printer.printRegion(region); 1220 } 1221 1222 LogicalResult ReductionDeclareOp::verify() { 1223 if (initializerRegion().empty()) 1224 return emitOpError() << "expects non-empty initializer region"; 1225 Block &initializerEntryBlock = initializerRegion().front(); 1226 if (initializerEntryBlock.getNumArguments() != 1 || 1227 initializerEntryBlock.getArgument(0).getType() != type()) { 1228 return emitOpError() << "expects initializer region with one argument " 1229 "of the reduction type"; 1230 } 1231 1232 for (YieldOp yieldOp : initializerRegion().getOps<YieldOp>()) { 1233 if (yieldOp.results().size() != 1 || 1234 yieldOp.results().getTypes()[0] != type()) 1235 return emitOpError() << "expects initializer region to yield a value " 1236 "of the reduction type"; 1237 } 1238 1239 if (reductionRegion().empty()) 1240 return emitOpError() << "expects non-empty reduction region"; 1241 Block &reductionEntryBlock = reductionRegion().front(); 1242 if (reductionEntryBlock.getNumArguments() != 2 || 1243 reductionEntryBlock.getArgumentTypes()[0] != 1244 reductionEntryBlock.getArgumentTypes()[1] || 1245 reductionEntryBlock.getArgumentTypes()[0] != type()) 1246 return emitOpError() << "expects reduction region with two arguments of " 1247 "the reduction type"; 1248 for (YieldOp yieldOp : reductionRegion().getOps<YieldOp>()) { 1249 if (yieldOp.results().size() != 1 || 1250 yieldOp.results().getTypes()[0] != type()) 1251 return emitOpError() << "expects reduction region to yield a value " 1252 "of the reduction type"; 1253 } 1254 1255 if (atomicReductionRegion().empty()) 1256 return success(); 1257 1258 Block &atomicReductionEntryBlock = atomicReductionRegion().front(); 1259 if (atomicReductionEntryBlock.getNumArguments() != 2 || 1260 atomicReductionEntryBlock.getArgumentTypes()[0] != 1261 atomicReductionEntryBlock.getArgumentTypes()[1]) 1262 return emitOpError() << "expects atomic reduction region with two " 1263 "arguments of the same type"; 1264 auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0] 1265 .dyn_cast<PointerLikeType>(); 1266 if (!ptrType || ptrType.getElementType() != type()) 1267 return emitOpError() << "expects atomic reduction region arguments to " 1268 "be accumulators containing the reduction type"; 1269 return success(); 1270 } 1271 1272 LogicalResult ReductionOp::verify() { 1273 // TODO: generalize this to an op interface when there is more than one op 1274 // that supports reductions. 1275 auto container = (*this)->getParentOfType<WsLoopOp>(); 1276 for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) 1277 if (container.reduction_vars()[i] == accumulator()) 1278 return success(); 1279 1280 return emitOpError() << "the accumulator is not used by the parent"; 1281 } 1282 1283 //===----------------------------------------------------------------------===// 1284 // WsLoopOp 1285 //===----------------------------------------------------------------------===// 1286 1287 void WsLoopOp::build(OpBuilder &builder, OperationState &state, 1288 ValueRange lowerBound, ValueRange upperBound, 1289 ValueRange step, ArrayRef<NamedAttribute> attributes) { 1290 build(builder, state, TypeRange(), lowerBound, upperBound, step, 1291 /*privateVars=*/ValueRange(), 1292 /*firstprivateVars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(), 1293 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), 1294 /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr, 1295 /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr, 1296 /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr, 1297 /*inclusive=*/nullptr, /*buildBody=*/false); 1298 state.addAttributes(attributes); 1299 } 1300 1301 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes, 1302 ValueRange operands, ArrayRef<NamedAttribute> attributes) { 1303 state.addOperands(operands); 1304 state.addAttributes(attributes); 1305 (void)state.addRegion(); 1306 assert(resultTypes.empty() && "mismatched number of return types"); 1307 state.addTypes(resultTypes); 1308 } 1309 1310 void WsLoopOp::build(OpBuilder &builder, OperationState &result, 1311 TypeRange typeRange, ValueRange lowerBounds, 1312 ValueRange upperBounds, ValueRange steps, 1313 ValueRange privateVars, ValueRange firstprivateVars, 1314 ValueRange lastprivateVars, ValueRange linearVars, 1315 ValueRange linearStepVars, ValueRange reductionVars, 1316 StringAttr scheduleVal, Value scheduleChunkVar, 1317 IntegerAttr collapseVal, UnitAttr nowait, 1318 IntegerAttr orderedVal, StringAttr orderVal, 1319 UnitAttr inclusive, bool buildBody) { 1320 result.addOperands(lowerBounds); 1321 result.addOperands(upperBounds); 1322 result.addOperands(steps); 1323 result.addOperands(privateVars); 1324 result.addOperands(firstprivateVars); 1325 result.addOperands(linearVars); 1326 result.addOperands(linearStepVars); 1327 if (scheduleChunkVar) 1328 result.addOperands(scheduleChunkVar); 1329 1330 if (scheduleVal) 1331 result.addAttribute("schedule_val", scheduleVal); 1332 if (collapseVal) 1333 result.addAttribute("collapse_val", collapseVal); 1334 if (nowait) 1335 result.addAttribute("nowait", nowait); 1336 if (orderedVal) 1337 result.addAttribute("ordered_val", orderedVal); 1338 if (orderVal) 1339 result.addAttribute("order", orderVal); 1340 if (inclusive) 1341 result.addAttribute("inclusive", inclusive); 1342 result.addAttribute( 1343 WsLoopOp::getOperandSegmentSizeAttr(), 1344 builder.getI32VectorAttr( 1345 {static_cast<int32_t>(lowerBounds.size()), 1346 static_cast<int32_t>(upperBounds.size()), 1347 static_cast<int32_t>(steps.size()), 1348 static_cast<int32_t>(privateVars.size()), 1349 static_cast<int32_t>(firstprivateVars.size()), 1350 static_cast<int32_t>(lastprivateVars.size()), 1351 static_cast<int32_t>(linearVars.size()), 1352 static_cast<int32_t>(linearStepVars.size()), 1353 static_cast<int32_t>(reductionVars.size()), 1354 static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)})); 1355 1356 Region *bodyRegion = result.addRegion(); 1357 if (buildBody) { 1358 OpBuilder::InsertionGuard guard(builder); 1359 unsigned numIVs = steps.size(); 1360 SmallVector<Type, 8> argTypes(numIVs, steps.getType().front()); 1361 SmallVector<Location, 8> argLocs(numIVs, result.location); 1362 builder.createBlock(bodyRegion, {}, argTypes, argLocs); 1363 } 1364 } 1365 1366 LogicalResult WsLoopOp::verify() { 1367 return verifyReductionVarList(*this, reductions(), reduction_vars()); 1368 } 1369 1370 //===----------------------------------------------------------------------===// 1371 // Verifier for critical construct (2.17.1) 1372 //===----------------------------------------------------------------------===// 1373 1374 LogicalResult CriticalDeclareOp::verify() { 1375 return verifySynchronizationHint(*this, hint()); 1376 } 1377 1378 LogicalResult CriticalOp::verify() { 1379 if (nameAttr()) { 1380 SymbolRefAttr symbolRef = nameAttr(); 1381 auto decl = SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>( 1382 *this, symbolRef); 1383 if (!decl) { 1384 return emitOpError() << "expected symbol reference " << symbolRef 1385 << " to point to a critical declaration"; 1386 } 1387 } 1388 1389 return success(); 1390 } 1391 1392 //===----------------------------------------------------------------------===// 1393 // Verifier for ordered construct 1394 //===----------------------------------------------------------------------===// 1395 1396 LogicalResult OrderedOp::verify() { 1397 auto container = (*this)->getParentOfType<WsLoopOp>(); 1398 if (!container || !container.ordered_valAttr() || 1399 container.ordered_valAttr().getInt() == 0) 1400 return emitOpError() << "ordered depend directive must be closely " 1401 << "nested inside a worksharing-loop with ordered " 1402 << "clause with parameter present"; 1403 1404 if (container.ordered_valAttr().getInt() != 1405 (int64_t)num_loops_val().getValue()) 1406 return emitOpError() << "number of variables in depend clause does not " 1407 << "match number of iteration variables in the " 1408 << "doacross loop"; 1409 1410 return success(); 1411 } 1412 1413 LogicalResult OrderedRegionOp::verify() { 1414 // TODO: The code generation for ordered simd directive is not supported yet. 1415 if (simd()) 1416 return failure(); 1417 1418 if (auto container = (*this)->getParentOfType<WsLoopOp>()) { 1419 if (!container.ordered_valAttr() || 1420 container.ordered_valAttr().getInt() != 0) 1421 return emitOpError() << "ordered region must be closely nested inside " 1422 << "a worksharing-loop region with an ordered " 1423 << "clause without parameter present"; 1424 } 1425 1426 return success(); 1427 } 1428 1429 //===----------------------------------------------------------------------===// 1430 // AtomicReadOp 1431 //===----------------------------------------------------------------------===// 1432 1433 /// Parser for AtomicReadOp 1434 /// 1435 /// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type 1436 /// address ::= operand `:` type 1437 ParseResult AtomicReadOp::parse(OpAsmParser &parser, OperationState &result) { 1438 OpAsmParser::OperandType x, v; 1439 Type addressType; 1440 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1441 SmallVector<int> segments; 1442 1443 if (parser.parseOperand(v) || parser.parseEqual() || parser.parseOperand(x) || 1444 parseClauses(parser, result, clauses, segments) || 1445 parser.parseColonType(addressType) || 1446 parser.resolveOperand(x, addressType, result.operands) || 1447 parser.resolveOperand(v, addressType, result.operands)) 1448 return failure(); 1449 1450 return success(); 1451 } 1452 1453 void AtomicReadOp::print(OpAsmPrinter &p) { 1454 p << " " << v() << " = " << x() << " "; 1455 if (auto mo = memory_order()) 1456 p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") "; 1457 if (hintAttr()) 1458 printSynchronizationHint(p << " ", *this, hintAttr()); 1459 p << ": " << x().getType(); 1460 } 1461 1462 /// Verifier for AtomicReadOp 1463 LogicalResult AtomicReadOp::verify() { 1464 if (auto mo = memory_order()) { 1465 if (*mo == ClauseMemoryOrderKind::acq_rel || 1466 *mo == ClauseMemoryOrderKind::release) { 1467 return emitError( 1468 "memory-order must not be acq_rel or release for atomic reads"); 1469 } 1470 } 1471 if (x() == v()) 1472 return emitError( 1473 "read and write must not be to the same location for atomic reads"); 1474 return verifySynchronizationHint(*this, hint()); 1475 } 1476 1477 //===----------------------------------------------------------------------===// 1478 // AtomicWriteOp 1479 //===----------------------------------------------------------------------===// 1480 1481 /// Parser for AtomicWriteOp 1482 /// 1483 /// operation ::= `omp.atomic.write` atomic-clause-list operands 1484 /// operands ::= address `,` value 1485 /// address ::= operand `:` type 1486 /// value ::= operand `:` type 1487 ParseResult AtomicWriteOp::parse(OpAsmParser &parser, OperationState &result) { 1488 OpAsmParser::OperandType address, value; 1489 Type addrType, valueType; 1490 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1491 SmallVector<int> segments; 1492 1493 if (parser.parseOperand(address) || parser.parseEqual() || 1494 parser.parseOperand(value) || 1495 parseClauses(parser, result, clauses, segments) || 1496 parser.parseColonType(addrType) || parser.parseComma() || 1497 parser.parseType(valueType) || 1498 parser.resolveOperand(address, addrType, result.operands) || 1499 parser.resolveOperand(value, valueType, result.operands)) 1500 return failure(); 1501 return success(); 1502 } 1503 1504 void AtomicWriteOp::print(OpAsmPrinter &p) { 1505 p << " " << address() << " = " << value() << " "; 1506 if (auto mo = memory_order()) 1507 p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") "; 1508 if (hintAttr()) 1509 printSynchronizationHint(p, *this, hintAttr()); 1510 p << ": " << address().getType() << ", " << value().getType(); 1511 } 1512 1513 /// Verifier for AtomicWriteOp 1514 LogicalResult AtomicWriteOp::verify() { 1515 if (auto mo = memory_order()) { 1516 if (*mo == ClauseMemoryOrderKind::acq_rel || 1517 *mo == ClauseMemoryOrderKind::acquire) { 1518 return emitError( 1519 "memory-order must not be acq_rel or acquire for atomic writes"); 1520 } 1521 } 1522 return verifySynchronizationHint(*this, hint()); 1523 } 1524 1525 //===----------------------------------------------------------------------===// 1526 // AtomicUpdateOp 1527 //===----------------------------------------------------------------------===// 1528 1529 /// Parser for AtomicUpdateOp 1530 /// 1531 /// operation ::= `omp.atomic.update` atomic-clause-list ssa-id-and-type region 1532 ParseResult AtomicUpdateOp::parse(OpAsmParser &parser, OperationState &result) { 1533 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1534 SmallVector<int> segments; 1535 OpAsmParser::OperandType x, expr; 1536 Type xType; 1537 1538 if (parseClauses(parser, result, clauses, segments) || 1539 parser.parseOperand(x) || parser.parseColon() || 1540 parser.parseType(xType) || 1541 parser.resolveOperand(x, xType, result.operands) || 1542 parser.parseRegion(*result.addRegion())) 1543 return failure(); 1544 return success(); 1545 } 1546 1547 void AtomicUpdateOp::print(OpAsmPrinter &p) { 1548 p << " "; 1549 if (auto mo = memory_order()) 1550 p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") "; 1551 if (hintAttr()) 1552 printSynchronizationHint(p, *this, hintAttr()); 1553 p << x() << " : " << x().getType(); 1554 p.printRegion(region()); 1555 } 1556 1557 /// Verifier for AtomicUpdateOp 1558 LogicalResult AtomicUpdateOp::verify() { 1559 if (auto mo = memory_order()) { 1560 if (*mo == ClauseMemoryOrderKind::acq_rel || 1561 *mo == ClauseMemoryOrderKind::acquire) { 1562 return emitError( 1563 "memory-order must not be acq_rel or acquire for atomic updates"); 1564 } 1565 } 1566 1567 if (region().getNumArguments() != 1) 1568 return emitError("the region must accept exactly one argument"); 1569 1570 if (x().getType().cast<PointerLikeType>().getElementType() != 1571 region().getArgument(0).getType()) { 1572 return emitError("the type of the operand must be a pointer type whose " 1573 "element type is the same as that of the region argument"); 1574 } 1575 1576 YieldOp yieldOp = *region().getOps<YieldOp>().begin(); 1577 1578 if (yieldOp.results().size() != 1) 1579 return emitError("only updated value must be returned"); 1580 if (yieldOp.results().front().getType() != region().getArgument(0).getType()) 1581 return emitError("input and yielded value must have the same type"); 1582 return success(); 1583 } 1584 1585 //===----------------------------------------------------------------------===// 1586 // AtomicCaptureOp 1587 //===----------------------------------------------------------------------===// 1588 1589 ParseResult AtomicCaptureOp::parse(OpAsmParser &parser, 1590 OperationState &result) { 1591 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1592 SmallVector<int> segments; 1593 if (parseClauses(parser, result, clauses, segments) || 1594 parser.parseRegion(*result.addRegion())) 1595 return failure(); 1596 return success(); 1597 } 1598 1599 void AtomicCaptureOp::print(OpAsmPrinter &p) { 1600 if (memory_order()) 1601 p << "memory_order(" << memory_order() << ") "; 1602 if (hintAttr()) 1603 printSynchronizationHint(p, *this, hintAttr()); 1604 p.printRegion(region()); 1605 } 1606 1607 /// Verifier for AtomicCaptureOp 1608 LogicalResult AtomicCaptureOp::verify() { 1609 Block::OpListType &ops = region().front().getOperations(); 1610 if (ops.size() != 3) 1611 return emitError() 1612 << "expected three operations in omp.atomic.capture region (one " 1613 "terminator, and two atomic ops)"; 1614 auto &firstOp = ops.front(); 1615 auto &secondOp = *ops.getNextNode(firstOp); 1616 auto firstReadStmt = dyn_cast<AtomicReadOp>(firstOp); 1617 auto firstUpdateStmt = dyn_cast<AtomicUpdateOp>(firstOp); 1618 auto secondReadStmt = dyn_cast<AtomicReadOp>(secondOp); 1619 auto secondUpdateStmt = dyn_cast<AtomicUpdateOp>(secondOp); 1620 auto secondWriteStmt = dyn_cast<AtomicWriteOp>(secondOp); 1621 1622 if (!((firstUpdateStmt && secondReadStmt) || 1623 (firstReadStmt && secondUpdateStmt) || 1624 (firstReadStmt && secondWriteStmt))) 1625 return ops.front().emitError() 1626 << "invalid sequence of operations in the capture region"; 1627 if (firstUpdateStmt && secondReadStmt && 1628 firstUpdateStmt.x() != secondReadStmt.x()) 1629 return firstUpdateStmt.emitError() 1630 << "updated variable in omp.atomic.update must be captured in " 1631 "second operation"; 1632 if (firstReadStmt && secondUpdateStmt && 1633 firstReadStmt.x() != secondUpdateStmt.x()) 1634 return firstReadStmt.emitError() 1635 << "captured variable in omp.atomic.read must be updated in second " 1636 "operation"; 1637 if (firstReadStmt && secondWriteStmt && 1638 firstReadStmt.x() != secondWriteStmt.address()) 1639 return firstReadStmt.emitError() 1640 << "captured variable in omp.atomic.read must be updated in " 1641 "second operation"; 1642 return success(); 1643 } 1644 1645 #define GET_ATTRDEF_CLASSES 1646 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" 1647 1648 #define GET_OP_CLASSES 1649 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 1650