1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the OpenMP dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/IR/Attributes.h" 17 #include "mlir/IR/DialectImplementation.h" 18 #include "mlir/IR/OpImplementation.h" 19 #include "mlir/IR/OperationSupport.h" 20 21 #include "llvm/ADT/BitVector.h" 22 #include "llvm/ADT/SmallString.h" 23 #include "llvm/ADT/StringExtras.h" 24 #include "llvm/ADT/StringRef.h" 25 #include "llvm/ADT/StringSwitch.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 #include <cstddef> 28 29 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc" 30 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" 31 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc" 32 33 using namespace mlir; 34 using namespace mlir::omp; 35 36 namespace { 37 /// Model for pointer-like types that already provide a `getElementType` method. 
38 template <typename T> 39 struct PointerLikeModel 40 : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> { 41 Type getElementType(Type pointer) const { 42 return pointer.cast<T>().getElementType(); 43 } 44 }; 45 } // namespace 46 47 void OpenMPDialect::initialize() { 48 addOperations< 49 #define GET_OP_LIST 50 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 51 >(); 52 addAttributes< 53 #define GET_ATTRDEF_LIST 54 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" 55 >(); 56 57 LLVM::LLVMPointerType::attachInterface< 58 PointerLikeModel<LLVM::LLVMPointerType>>(*getContext()); 59 MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext()); 60 } 61 62 //===----------------------------------------------------------------------===// 63 // ParallelOp 64 //===----------------------------------------------------------------------===// 65 66 void ParallelOp::build(OpBuilder &builder, OperationState &state, 67 ArrayRef<NamedAttribute> attributes) { 68 ParallelOp::build( 69 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, 70 /*default_val=*/nullptr, /*private_vars=*/ValueRange(), 71 /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(), 72 /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(), 73 /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr); 74 state.addAttributes(attributes); 75 } 76 77 //===----------------------------------------------------------------------===// 78 // Parser and printer for Operand and type list 79 //===----------------------------------------------------------------------===// 80 81 /// Parse a list of operands with types. 
82 /// 83 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)` 84 /// ssa-id-and-type-list ::= ssa-id-and-type | 85 /// ssa-id-and-type `,` ssa-id-and-type-list 86 /// ssa-id-and-type ::= ssa-id `:` type 87 static ParseResult 88 parseOperandAndTypeList(OpAsmParser &parser, 89 SmallVectorImpl<OpAsmParser::OperandType> &operands, 90 SmallVectorImpl<Type> &types) { 91 return parser.parseCommaSeparatedList( 92 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 93 OpAsmParser::OperandType operand; 94 Type type; 95 if (parser.parseOperand(operand) || parser.parseColonType(type)) 96 return failure(); 97 operands.push_back(operand); 98 types.push_back(type); 99 return success(); 100 }); 101 } 102 103 /// Print an operand and type list with parentheses 104 static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) { 105 p << "("; 106 llvm::interleaveComma( 107 operands, p, [&](const Value &v) { p << v << " : " << v.getType(); }); 108 p << ") "; 109 } 110 111 /// Print data variables corresponding to a data-sharing clause `name` 112 static void printDataVars(OpAsmPrinter &p, OperandRange operands, 113 StringRef name) { 114 if (!operands.empty()) { 115 p << name; 116 printOperandAndTypeList(p, operands); 117 } 118 } 119 120 //===----------------------------------------------------------------------===// 121 // Parser and printer for Allocate Clause 122 //===----------------------------------------------------------------------===// 123 124 /// Parse an allocate clause with allocators and a list of operands with types. 
125 /// 126 /// allocate ::= `allocate` `(` allocate-operand-list `)` 127 /// allocate-operand-list :: = allocate-operand | 128 /// allocator-operand `,` allocate-operand-list 129 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 130 /// ssa-id-and-type ::= ssa-id `:` type 131 static ParseResult parseAllocateAndAllocator( 132 OpAsmParser &parser, 133 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate, 134 SmallVectorImpl<Type> &typesAllocate, 135 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator, 136 SmallVectorImpl<Type> &typesAllocator) { 137 138 return parser.parseCommaSeparatedList( 139 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 140 OpAsmParser::OperandType operand; 141 Type type; 142 if (parser.parseOperand(operand) || parser.parseColonType(type)) 143 return failure(); 144 operandsAllocator.push_back(operand); 145 typesAllocator.push_back(type); 146 if (parser.parseArrow()) 147 return failure(); 148 if (parser.parseOperand(operand) || parser.parseColonType(type)) 149 return failure(); 150 151 operandsAllocate.push_back(operand); 152 typesAllocate.push_back(type); 153 return success(); 154 }); 155 } 156 157 /// Print allocate clause 158 static void printAllocateAndAllocator(OpAsmPrinter &p, 159 OperandRange varsAllocate, 160 OperandRange varsAllocator) { 161 p << "allocate("; 162 for (unsigned i = 0; i < varsAllocate.size(); ++i) { 163 std::string separator = i == varsAllocate.size() - 1 ? 
") " : ", "; 164 p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> "; 165 p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator; 166 } 167 } 168 169 static LogicalResult verifyParallelOp(ParallelOp op) { 170 if (op.allocate_vars().size() != op.allocators_vars().size()) 171 return op.emitError( 172 "expected equal sizes for allocate and allocator variables"); 173 return success(); 174 } 175 176 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) { 177 p << " "; 178 if (auto ifCond = op.if_expr_var()) 179 p << "if(" << ifCond << " : " << ifCond.getType() << ") "; 180 181 if (auto threads = op.num_threads_var()) 182 p << "num_threads(" << threads << " : " << threads.getType() << ") "; 183 184 printDataVars(p, op.private_vars(), "private"); 185 printDataVars(p, op.firstprivate_vars(), "firstprivate"); 186 printDataVars(p, op.shared_vars(), "shared"); 187 printDataVars(p, op.copyin_vars(), "copyin"); 188 189 if (!op.allocate_vars().empty()) 190 printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars()); 191 192 if (auto def = op.default_val()) 193 p << "default(" << stringifyClauseDefault(*def).drop_front(3) << ") "; 194 195 if (auto bind = op.proc_bind_val()) 196 p << "proc_bind(" << stringifyClauseProcBindKind(*bind) << ") "; 197 198 p << ' '; 199 p.printRegion(op.getRegion()); 200 } 201 202 //===----------------------------------------------------------------------===// 203 // Parser and printer for Linear Clause 204 //===----------------------------------------------------------------------===// 205 206 /// linear ::= `linear` `(` linear-list `)` 207 /// linear-list := linear-val | linear-val linear-list 208 /// linear-val := ssa-id-and-type `=` ssa-id-and-type 209 static ParseResult 210 parseLinearClause(OpAsmParser &parser, 211 SmallVectorImpl<OpAsmParser::OperandType> &vars, 212 SmallVectorImpl<Type> &types, 213 SmallVectorImpl<OpAsmParser::OperandType> &stepVars) { 214 if (parser.parseLParen()) 215 return 
failure(); 216 217 do { 218 OpAsmParser::OperandType var; 219 Type type; 220 OpAsmParser::OperandType stepVar; 221 if (parser.parseOperand(var) || parser.parseEqual() || 222 parser.parseOperand(stepVar) || parser.parseColonType(type)) 223 return failure(); 224 225 vars.push_back(var); 226 types.push_back(type); 227 stepVars.push_back(stepVar); 228 } while (succeeded(parser.parseOptionalComma())); 229 230 if (parser.parseRParen()) 231 return failure(); 232 233 return success(); 234 } 235 236 /// Print Linear Clause 237 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars, 238 OperandRange linearStepVars) { 239 size_t linearVarsSize = linearVars.size(); 240 p << "linear("; 241 for (unsigned i = 0; i < linearVarsSize; ++i) { 242 std::string separator = i == linearVarsSize - 1 ? ") " : ", "; 243 p << linearVars[i]; 244 if (linearStepVars.size() > i) 245 p << " = " << linearStepVars[i]; 246 p << " : " << linearVars[i].getType() << separator; 247 } 248 } 249 250 //===----------------------------------------------------------------------===// 251 // Parser and printer for Schedule Clause 252 //===----------------------------------------------------------------------===// 253 254 static ParseResult 255 verifyScheduleModifiers(OpAsmParser &parser, 256 SmallVectorImpl<SmallString<12>> &modifiers) { 257 if (modifiers.size() > 2) 258 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)"; 259 for (const auto &mod : modifiers) { 260 // Translate the string. If it has no value, then it was not a valid 261 // modifier! 262 auto symbol = symbolizeScheduleModifier(mod); 263 if (!symbol.hasValue()) 264 return parser.emitError(parser.getNameLoc()) 265 << " unknown modifier type: " << mod; 266 } 267 268 // If we have one modifier that is "simd", then stick a "none" modiifer in 269 // index 0. 
270 if (modifiers.size() == 1) { 271 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) { 272 modifiers.push_back(modifiers[0]); 273 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none); 274 } 275 } else if (modifiers.size() == 2) { 276 // If there are two modifier: 277 // First modifier should not be simd, second one should be simd 278 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd || 279 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd) 280 return parser.emitError(parser.getNameLoc()) 281 << " incorrect modifier order"; 282 } 283 return success(); 284 } 285 286 /// schedule ::= `schedule` `(` sched-list `)` 287 /// sched-list ::= sched-val | sched-val sched-list | 288 /// sched-val `,` sched-modifier 289 /// sched-val ::= sched-with-chunk | sched-wo-chunk 290 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 291 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided` 292 /// sched-wo-chunk ::= `auto` | `runtime` 293 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val 294 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none` 295 static ParseResult 296 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, 297 SmallVectorImpl<SmallString<12>> &modifiers, 298 Optional<OpAsmParser::OperandType> &chunkSize) { 299 if (parser.parseLParen()) 300 return failure(); 301 302 StringRef keyword; 303 if (parser.parseKeyword(&keyword)) 304 return failure(); 305 306 schedule = keyword; 307 if (keyword == "static" || keyword == "dynamic" || keyword == "guided") { 308 if (succeeded(parser.parseOptionalEqual())) { 309 chunkSize = OpAsmParser::OperandType{}; 310 if (parser.parseOperand(*chunkSize)) 311 return failure(); 312 } else { 313 chunkSize = llvm::NoneType::None; 314 } 315 } else if (keyword == "auto" || keyword == "runtime") { 316 chunkSize = llvm::NoneType::None; 317 } else { 318 return parser.emitError(parser.getNameLoc()) << " 
expected schedule kind"; 319 } 320 321 // If there is a comma, we have one or more modifiers.. 322 while (succeeded(parser.parseOptionalComma())) { 323 StringRef mod; 324 if (parser.parseKeyword(&mod)) 325 return failure(); 326 modifiers.push_back(mod); 327 } 328 329 if (parser.parseRParen()) 330 return failure(); 331 332 if (verifyScheduleModifiers(parser, modifiers)) 333 return failure(); 334 335 return success(); 336 } 337 338 /// Print schedule clause 339 static void printScheduleClause(OpAsmPrinter &p, ClauseScheduleKind sched, 340 Optional<ScheduleModifier> modifier, bool simd, 341 Value scheduleChunkVar) { 342 p << "schedule(" << stringifyClauseScheduleKind(sched).lower(); 343 if (scheduleChunkVar) 344 p << " = " << scheduleChunkVar; 345 if (modifier) 346 p << ", " << stringifyScheduleModifier(*modifier); 347 if (simd) 348 p << ", simd"; 349 p << ") "; 350 } 351 352 //===----------------------------------------------------------------------===// 353 // Parser, printer and verifier for ReductionVarList 354 //===----------------------------------------------------------------------===// 355 356 /// reduction ::= `reduction` `(` reduction-entry-list `)` 357 /// reduction-entry-list ::= reduction-entry 358 /// | reduction-entry-list `,` reduction-entry 359 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type 360 static ParseResult 361 parseReductionVarList(OpAsmParser &parser, 362 SmallVectorImpl<SymbolRefAttr> &symbols, 363 SmallVectorImpl<OpAsmParser::OperandType> &operands, 364 SmallVectorImpl<Type> &types) { 365 if (failed(parser.parseLParen())) 366 return failure(); 367 368 do { 369 if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() || 370 parser.parseOperand(operands.emplace_back()) || 371 parser.parseColonType(types.emplace_back())) 372 return failure(); 373 } while (succeeded(parser.parseOptionalComma())); 374 return parser.parseRParen(); 375 } 376 377 /// Print Reduction clause 378 static void printReductionVarList(OpAsmPrinter 
&p, 379 Optional<ArrayAttr> reductions, 380 OperandRange reductionVars) { 381 p << "reduction("; 382 for (unsigned i = 0, e = reductions->size(); i < e; ++i) { 383 if (i != 0) 384 p << ", "; 385 p << (*reductions)[i] << " -> " << reductionVars[i] << " : " 386 << reductionVars[i].getType(); 387 } 388 p << ") "; 389 } 390 391 /// Verifies Reduction Clause 392 static LogicalResult verifyReductionVarList(Operation *op, 393 Optional<ArrayAttr> reductions, 394 OperandRange reductionVars) { 395 if (!reductionVars.empty()) { 396 if (!reductions || reductions->size() != reductionVars.size()) 397 return op->emitOpError() 398 << "expected as many reduction symbol references " 399 "as reduction variables"; 400 } else { 401 if (reductions) 402 return op->emitOpError() << "unexpected reduction symbol references"; 403 return success(); 404 } 405 406 DenseSet<Value> accumulators; 407 for (auto args : llvm::zip(reductionVars, *reductions)) { 408 Value accum = std::get<0>(args); 409 410 if (!accumulators.insert(accum).second) 411 return op->emitOpError() << "accumulator variable used more than once"; 412 413 Type varType = accum.getType().cast<PointerLikeType>(); 414 auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>(); 415 auto decl = 416 SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef); 417 if (!decl) 418 return op->emitOpError() << "expected symbol reference " << symbolRef 419 << " to point to a reduction declaration"; 420 421 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) 422 return op->emitOpError() 423 << "expected accumulator (" << varType 424 << ") to be the same type as reduction declaration (" 425 << decl.getAccumulatorType() << ")"; 426 } 427 428 return success(); 429 } 430 431 //===----------------------------------------------------------------------===// 432 // Parser, printer and verifier for Synchronization Hint (2.17.12) 433 //===----------------------------------------------------------------------===// 434 435 /// 
Parses a Synchronization Hint clause. The value of hint is an integer 436 /// which is a combination of different hints from `omp_sync_hint_t`. 437 /// 438 /// hint-clause = `hint` `(` hint-value `)` 439 static ParseResult parseSynchronizationHint(OpAsmParser &parser, 440 IntegerAttr &hintAttr, 441 bool parseKeyword = true) { 442 if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) { 443 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0); 444 return success(); 445 } 446 447 if (failed(parser.parseLParen())) 448 return failure(); 449 StringRef hintKeyword; 450 int64_t hint = 0; 451 do { 452 if (failed(parser.parseKeyword(&hintKeyword))) 453 return failure(); 454 if (hintKeyword == "uncontended") 455 hint |= 1; 456 else if (hintKeyword == "contended") 457 hint |= 2; 458 else if (hintKeyword == "nonspeculative") 459 hint |= 4; 460 else if (hintKeyword == "speculative") 461 hint |= 8; 462 else 463 return parser.emitError(parser.getCurrentLocation()) 464 << hintKeyword << " is not a valid hint"; 465 } while (succeeded(parser.parseOptionalComma())); 466 if (failed(parser.parseRParen())) 467 return failure(); 468 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint); 469 return success(); 470 } 471 472 /// Prints a Synchronization Hint clause 473 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, 474 IntegerAttr hintAttr) { 475 int64_t hint = hintAttr.getInt(); 476 477 if (hint == 0) 478 return; 479 480 // Helper function to get n-th bit from the right end of `value` 481 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 482 483 bool uncontended = bitn(hint, 0); 484 bool contended = bitn(hint, 1); 485 bool nonspeculative = bitn(hint, 2); 486 bool speculative = bitn(hint, 3); 487 488 SmallVector<StringRef> hints; 489 if (uncontended) 490 hints.push_back("uncontended"); 491 if (contended) 492 hints.push_back("contended"); 493 if (nonspeculative) 494 hints.push_back("nonspeculative"); 495 if 
(speculative) 496 hints.push_back("speculative"); 497 498 p << "hint("; 499 llvm::interleaveComma(hints, p); 500 p << ") "; 501 } 502 503 /// Verifies a synchronization hint clause 504 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) { 505 506 // Helper function to get n-th bit from the right end of `value` 507 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 508 509 bool uncontended = bitn(hint, 0); 510 bool contended = bitn(hint, 1); 511 bool nonspeculative = bitn(hint, 2); 512 bool speculative = bitn(hint, 3); 513 514 if (uncontended && contended) 515 return op->emitOpError() << "the hints omp_sync_hint_uncontended and " 516 "omp_sync_hint_contended cannot be combined"; 517 if (nonspeculative && speculative) 518 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and " 519 "omp_sync_hint_speculative cannot be combined."; 520 return success(); 521 } 522 523 enum ClauseType { 524 ifClause, 525 numThreadsClause, 526 privateClause, 527 firstprivateClause, 528 lastprivateClause, 529 sharedClause, 530 copyinClause, 531 allocateClause, 532 defaultClause, 533 procBindClause, 534 reductionClause, 535 nowaitClause, 536 linearClause, 537 scheduleClause, 538 collapseClause, 539 orderClause, 540 orderedClause, 541 memoryOrderClause, 542 hintClause, 543 COUNT 544 }; 545 546 //===----------------------------------------------------------------------===// 547 // Parser for Clause List 548 //===----------------------------------------------------------------------===// 549 550 /// Parse a clause attribute `(` $value `)`. 
551 template <typename ClauseAttr> 552 static ParseResult parseClauseAttr(AsmParser &parser, OperationState &state, 553 StringRef attrName, StringRef name) { 554 using ClauseT = decltype(std::declval<ClauseAttr>().getValue()); 555 StringRef enumStr; 556 llvm::SMLoc loc = parser.getCurrentLocation(); 557 if (parser.parseLParen() || parser.parseKeyword(&enumStr) || 558 parser.parseRParen()) 559 return failure(); 560 if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) { 561 auto attr = ClauseAttr::get(parser.getContext(), *enumValue); 562 state.addAttribute(attrName, attr); 563 return success(); 564 } 565 return parser.emitError(loc, "invalid ") << name << " kind"; 566 } 567 568 /// Parse a list of clauses. The clauses can appear in any order, but their 569 /// operand segment indices are in the same order that they are passed in the 570 /// `clauses` list. The operand segments are added over the prevSegments 571 572 /// clause-list ::= clause clause-list | empty 573 /// clause ::= if | num-threads | private | firstprivate | lastprivate | 574 /// shared | copyin | allocate | default | proc-bind | reduction | 575 /// nowait | linear | schedule | collapse | order | ordered | 576 /// inclusive 577 /// if ::= `if` `(` ssa-id-and-type `)` 578 /// num-threads ::= `num_threads` `(` ssa-id-and-type `)` 579 /// private ::= `private` operand-and-type-list 580 /// firstprivate ::= `firstprivate` operand-and-type-list 581 /// lastprivate ::= `lastprivate` operand-and-type-list 582 /// shared ::= `shared` operand-and-type-list 583 /// copyin ::= `copyin` operand-and-type-list 584 /// allocate ::= `allocate` `(` allocate-operand-list `)` 585 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) 586 /// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` 587 /// reduction ::= `reduction` `(` reduction-entry-list `)` 588 /// nowait ::= `nowait` 589 /// linear ::= `linear` `(` linear-list `)` 590 /// schedule ::= `schedule` `(` 
sched-list `)` 591 /// collapse ::= `collapse` `(` ssa-id-and-type `)` 592 /// order ::= `order` `(` `concurrent` `)` 593 /// ordered ::= `ordered` `(` ssa-id-and-type `)` 594 /// inclusive ::= `inclusive` 595 /// 596 /// Note that each clause can only appear once in the clase-list. 597 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result, 598 SmallVectorImpl<ClauseType> &clauses, 599 SmallVectorImpl<int> &segments) { 600 601 // Check done[clause] to see if it has been parsed already 602 llvm::BitVector done(ClauseType::COUNT, false); 603 604 // See pos[clause] to get position of clause in operand segments 605 SmallVector<int> pos(ClauseType::COUNT, -1); 606 607 // Stores the last parsed clause keyword 608 StringRef clauseKeyword; 609 StringRef opName = result.name.getStringRef(); 610 611 // Containers for storing operands, types and attributes for various clauses 612 std::pair<OpAsmParser::OperandType, Type> ifCond; 613 std::pair<OpAsmParser::OperandType, Type> numThreads; 614 615 SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates, 616 shareds, copyins; 617 SmallVector<Type> privateTypes, firstprivateTypes, lastprivateTypes, 618 sharedTypes, copyinTypes; 619 620 SmallVector<OpAsmParser::OperandType> allocates, allocators; 621 SmallVector<Type> allocateTypes, allocatorTypes; 622 623 SmallVector<SymbolRefAttr> reductionSymbols; 624 SmallVector<OpAsmParser::OperandType> reductionVars; 625 SmallVector<Type> reductionVarTypes; 626 627 SmallVector<OpAsmParser::OperandType> linears; 628 SmallVector<Type> linearTypes; 629 SmallVector<OpAsmParser::OperandType> linearSteps; 630 631 SmallString<8> schedule; 632 SmallVector<SmallString<12>> modifiers; 633 Optional<OpAsmParser::OperandType> scheduleChunkSize; 634 635 // Compute the position of clauses in operand segments 636 int currPos = 0; 637 for (ClauseType clause : clauses) { 638 639 // Skip the following clauses - they do not take any position in operand 640 // segments 641 
if (clause == defaultClause || clause == procBindClause || 642 clause == nowaitClause || clause == collapseClause || 643 clause == orderClause || clause == orderedClause) 644 continue; 645 646 pos[clause] = currPos++; 647 648 // For the following clauses, two positions are reserved in the operand 649 // segments 650 if (clause == allocateClause || clause == linearClause) 651 currPos++; 652 } 653 654 SmallVector<int> clauseSegments(currPos); 655 656 // Helper function to check if a clause is allowed/repeated or not 657 auto checkAllowed = [&](ClauseType clause) -> ParseResult { 658 if (!llvm::is_contained(clauses, clause)) 659 return parser.emitError(parser.getCurrentLocation()) 660 << clauseKeyword << " is not a valid clause for the " << opName 661 << " operation"; 662 if (done[clause]) 663 return parser.emitError(parser.getCurrentLocation()) 664 << "at most one " << clauseKeyword << " clause can appear on the " 665 << opName << " operation"; 666 done[clause] = true; 667 return success(); 668 }; 669 670 while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) { 671 if (clauseKeyword == "if") { 672 if (checkAllowed(ifClause) || parser.parseLParen() || 673 parser.parseOperand(ifCond.first) || 674 parser.parseColonType(ifCond.second) || parser.parseRParen()) 675 return failure(); 676 clauseSegments[pos[ifClause]] = 1; 677 } else if (clauseKeyword == "num_threads") { 678 if (checkAllowed(numThreadsClause) || parser.parseLParen() || 679 parser.parseOperand(numThreads.first) || 680 parser.parseColonType(numThreads.second) || parser.parseRParen()) 681 return failure(); 682 clauseSegments[pos[numThreadsClause]] = 1; 683 } else if (clauseKeyword == "private") { 684 if (checkAllowed(privateClause) || 685 parseOperandAndTypeList(parser, privates, privateTypes)) 686 return failure(); 687 clauseSegments[pos[privateClause]] = privates.size(); 688 } else if (clauseKeyword == "firstprivate") { 689 if (checkAllowed(firstprivateClause) || 690 parseOperandAndTypeList(parser, 
firstprivates, firstprivateTypes)) 691 return failure(); 692 clauseSegments[pos[firstprivateClause]] = firstprivates.size(); 693 } else if (clauseKeyword == "lastprivate") { 694 if (checkAllowed(lastprivateClause) || 695 parseOperandAndTypeList(parser, lastprivates, lastprivateTypes)) 696 return failure(); 697 clauseSegments[pos[lastprivateClause]] = lastprivates.size(); 698 } else if (clauseKeyword == "shared") { 699 if (checkAllowed(sharedClause) || 700 parseOperandAndTypeList(parser, shareds, sharedTypes)) 701 return failure(); 702 clauseSegments[pos[sharedClause]] = shareds.size(); 703 } else if (clauseKeyword == "copyin") { 704 if (checkAllowed(copyinClause) || 705 parseOperandAndTypeList(parser, copyins, copyinTypes)) 706 return failure(); 707 clauseSegments[pos[copyinClause]] = copyins.size(); 708 } else if (clauseKeyword == "allocate") { 709 if (checkAllowed(allocateClause) || 710 parseAllocateAndAllocator(parser, allocates, allocateTypes, 711 allocators, allocatorTypes)) 712 return failure(); 713 clauseSegments[pos[allocateClause]] = allocates.size(); 714 clauseSegments[pos[allocateClause] + 1] = allocators.size(); 715 } else if (clauseKeyword == "default") { 716 StringRef defval; 717 llvm::SMLoc loc = parser.getCurrentLocation(); 718 if (checkAllowed(defaultClause) || parser.parseLParen() || 719 parser.parseKeyword(&defval) || parser.parseRParen()) 720 return failure(); 721 // The def prefix is required for the attribute as "private" is a keyword 722 // in C++. 
723 if (Optional<ClauseDefault> def = 724 symbolizeClauseDefault(("def" + defval).str())) { 725 result.addAttribute("default_val", 726 ClauseDefaultAttr::get(parser.getContext(), *def)); 727 } else { 728 return parser.emitError(loc, "invalid default clause"); 729 } 730 } else if (clauseKeyword == "proc_bind") { 731 if (checkAllowed(procBindClause) || 732 parseClauseAttr<ClauseProcBindKindAttr>(parser, result, 733 "proc_bind_val", "proc bind")) 734 return failure(); 735 } else if (clauseKeyword == "reduction") { 736 if (checkAllowed(reductionClause) || 737 parseReductionVarList(parser, reductionSymbols, reductionVars, 738 reductionVarTypes)) 739 return failure(); 740 clauseSegments[pos[reductionClause]] = reductionVars.size(); 741 } else if (clauseKeyword == "nowait") { 742 if (checkAllowed(nowaitClause)) 743 return failure(); 744 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 745 result.addAttribute("nowait", attr); 746 } else if (clauseKeyword == "linear") { 747 if (checkAllowed(linearClause) || 748 parseLinearClause(parser, linears, linearTypes, linearSteps)) 749 return failure(); 750 clauseSegments[pos[linearClause]] = linears.size(); 751 clauseSegments[pos[linearClause] + 1] = linearSteps.size(); 752 } else if (clauseKeyword == "schedule") { 753 if (checkAllowed(scheduleClause) || 754 parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize)) 755 return failure(); 756 if (scheduleChunkSize) { 757 clauseSegments[pos[scheduleClause]] = 1; 758 } 759 } else if (clauseKeyword == "collapse") { 760 auto type = parser.getBuilder().getI64Type(); 761 mlir::IntegerAttr attr; 762 if (checkAllowed(collapseClause) || parser.parseLParen() || 763 parser.parseAttribute(attr, type) || parser.parseRParen()) 764 return failure(); 765 result.addAttribute("collapse_val", attr); 766 } else if (clauseKeyword == "ordered") { 767 mlir::IntegerAttr attr; 768 if (checkAllowed(orderedClause)) 769 return failure(); 770 if (succeeded(parser.parseOptionalLParen())) { 
771 auto type = parser.getBuilder().getI64Type(); 772 if (parser.parseAttribute(attr, type) || parser.parseRParen()) 773 return failure(); 774 } else { 775 // Use 0 to represent no ordered parameter was specified 776 attr = parser.getBuilder().getI64IntegerAttr(0); 777 } 778 result.addAttribute("ordered_val", attr); 779 } else if (clauseKeyword == "order") { 780 if (checkAllowed(orderClause) || 781 parseClauseAttr<ClauseOrderKindAttr>(parser, result, "order_val", 782 "order")) 783 return failure(); 784 } else if (clauseKeyword == "memory_order") { 785 if (checkAllowed(memoryOrderClause) || 786 parseClauseAttr<ClauseMemoryOrderKindAttr>( 787 parser, result, "memory_order", "memory order")) 788 return failure(); 789 } else if (clauseKeyword == "hint") { 790 IntegerAttr hint; 791 if (checkAllowed(hintClause) || 792 parseSynchronizationHint(parser, hint, false)) 793 return failure(); 794 result.addAttribute("hint", hint); 795 } else { 796 return parser.emitError(parser.getNameLoc()) 797 << clauseKeyword << " is not a valid clause"; 798 } 799 } 800 801 // Add if parameter. 802 if (done[ifClause] && clauseSegments[pos[ifClause]] && 803 failed( 804 parser.resolveOperand(ifCond.first, ifCond.second, result.operands))) 805 return failure(); 806 807 // Add num_threads parameter. 808 if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] && 809 failed(parser.resolveOperand(numThreads.first, numThreads.second, 810 result.operands))) 811 return failure(); 812 813 // Add private parameters. 814 if (done[privateClause] && clauseSegments[pos[privateClause]] && 815 failed(parser.resolveOperands(privates, privateTypes, 816 privates[0].location, result.operands))) 817 return failure(); 818 819 // Add firstprivate parameters. 
820 if (done[firstprivateClause] && clauseSegments[pos[firstprivateClause]] && 821 failed(parser.resolveOperands(firstprivates, firstprivateTypes, 822 firstprivates[0].location, 823 result.operands))) 824 return failure(); 825 826 // Add lastprivate parameters. 827 if (done[lastprivateClause] && clauseSegments[pos[lastprivateClause]] && 828 failed(parser.resolveOperands(lastprivates, lastprivateTypes, 829 lastprivates[0].location, result.operands))) 830 return failure(); 831 832 // Add shared parameters. 833 if (done[sharedClause] && clauseSegments[pos[sharedClause]] && 834 failed(parser.resolveOperands(shareds, sharedTypes, shareds[0].location, 835 result.operands))) 836 return failure(); 837 838 // Add copyin parameters. 839 if (done[copyinClause] && clauseSegments[pos[copyinClause]] && 840 failed(parser.resolveOperands(copyins, copyinTypes, copyins[0].location, 841 result.operands))) 842 return failure(); 843 844 // Add allocate parameters. 845 if (done[allocateClause] && clauseSegments[pos[allocateClause]] && 846 failed(parser.resolveOperands(allocates, allocateTypes, 847 allocates[0].location, result.operands))) 848 return failure(); 849 850 // Add allocator parameters. 
851 if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] && 852 failed(parser.resolveOperands(allocators, allocatorTypes, 853 allocators[0].location, result.operands))) 854 return failure(); 855 856 // Add reduction parameters and symbols 857 if (done[reductionClause] && clauseSegments[pos[reductionClause]]) { 858 if (failed(parser.resolveOperands(reductionVars, reductionVarTypes, 859 parser.getNameLoc(), result.operands))) 860 return failure(); 861 862 SmallVector<Attribute> reductions(reductionSymbols.begin(), 863 reductionSymbols.end()); 864 result.addAttribute("reductions", 865 parser.getBuilder().getArrayAttr(reductions)); 866 } 867 868 // Add linear parameters 869 if (done[linearClause] && clauseSegments[pos[linearClause]]) { 870 auto linearStepType = parser.getBuilder().getI32Type(); 871 SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType); 872 if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location, 873 result.operands)) || 874 failed(parser.resolveOperands(linearSteps, linearStepTypes, 875 linearSteps[0].location, 876 result.operands))) 877 return failure(); 878 } 879 880 // Add schedule parameters 881 if (done[scheduleClause] && !schedule.empty()) { 882 schedule[0] = llvm::toUpper(schedule[0]); 883 if (Optional<ClauseScheduleKind> sched = 884 symbolizeClauseScheduleKind(schedule)) { 885 auto attr = ClauseScheduleKindAttr::get(parser.getContext(), *sched); 886 result.addAttribute("schedule_val", attr); 887 } else { 888 return parser.emitError(parser.getCurrentLocation(), 889 "invalid schedule kind"); 890 } 891 if (!modifiers.empty()) { 892 llvm::SMLoc loc = parser.getCurrentLocation(); 893 if (Optional<ScheduleModifier> mod = 894 symbolizeScheduleModifier(modifiers[0])) { 895 result.addAttribute( 896 "schedule_modifier", 897 ScheduleModifierAttr::get(parser.getContext(), *mod)); 898 } else { 899 return parser.emitError(loc, "invalid schedule modifier"); 900 } 901 // Only SIMD attribute is allowed here! 
902 if (modifiers.size() > 1) { 903 assert(symbolizeScheduleModifier(modifiers[1]) == 904 ScheduleModifier::simd); 905 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 906 result.addAttribute("simd_modifier", attr); 907 } 908 } 909 if (scheduleChunkSize) { 910 auto chunkSizeType = parser.getBuilder().getI32Type(); 911 parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands); 912 } 913 } 914 915 segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end()); 916 917 return success(); 918 } 919 920 /// Parses a parallel operation. 921 /// 922 /// operation ::= `omp.parallel` clause-list 923 /// clause-list ::= clause | clause clause-list 924 /// clause ::= if | num-threads | private | firstprivate | shared | copyin | 925 /// allocate | default | proc-bind 926 /// 927 static ParseResult parseParallelOp(OpAsmParser &parser, 928 OperationState &result) { 929 SmallVector<ClauseType> clauses = { 930 ifClause, numThreadsClause, privateClause, 931 firstprivateClause, sharedClause, copyinClause, 932 allocateClause, defaultClause, procBindClause}; 933 934 SmallVector<int> segments; 935 936 if (failed(parseClauses(parser, result, clauses, segments))) 937 return failure(); 938 939 result.addAttribute("operand_segment_sizes", 940 parser.getBuilder().getI32VectorAttr(segments)); 941 942 Region *body = result.addRegion(); 943 SmallVector<OpAsmParser::OperandType> regionArgs; 944 SmallVector<Type> regionArgTypes; 945 if (parser.parseRegion(*body, regionArgs, regionArgTypes)) 946 return failure(); 947 return success(); 948 } 949 950 //===----------------------------------------------------------------------===// 951 // Parser, printer and verifier for SectionsOp 952 //===----------------------------------------------------------------------===// 953 954 /// Parses an OpenMP Sections operation 955 /// 956 /// sections ::= `omp.sections` clause-list 957 /// clause-list ::= clause clause-list | empty 958 /// clause ::= private | firstprivate 
///                | lastprivate | reduction | allocate | nowait
static ParseResult parseSectionsOp(OpAsmParser &parser,
                                   OperationState &result) {

  SmallVector<ClauseType> clauses = {privateClause, firstprivateClause,
                                     lastprivateClause, reductionClause,
                                     allocateClause, nowaitClause};

  SmallVector<int> segments;

  if (failed(parseClauses(parser, result, clauses, segments)))
    return failure();

  // parseClauses fills `segments` with the operand-group sizes; the op keeps
  // them in its segment-size attribute.
  result.addAttribute("operand_segment_sizes",
                      parser.getBuilder().getI32VectorAttr(segments));

  // Now parse the body.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body))
    return failure();
  return success();
}

/// Prints an OpenMP Sections operation: its private/firstprivate/lastprivate
/// lists, reduction, allocate and nowait clauses, followed by the body region.
static void printSectionsOp(OpAsmPrinter &p, SectionsOp op) {
  p << " ";
  printDataVars(p, op.private_vars(), "private");
  printDataVars(p, op.firstprivate_vars(), "firstprivate");
  printDataVars(p, op.lastprivate_vars(), "lastprivate");

  if (!op.reduction_vars().empty())
    printReductionVarList(p, op.reductions(), op.reduction_vars());

  if (!op.allocate_vars().empty())
    printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars());

  if (op.nowait())
    p << "nowait";

  p << ' ';
  p.printRegion(op.region());
}

/// Verifies an OpenMP Sections operation: data-sharing clause exclusivity,
/// matching allocate/allocator list lengths, the ops allowed directly in the
/// body, and the reduction clause.
static LogicalResult verifySectionsOp(SectionsOp op) {

  // A list item may not appear in more than one clause on the same directive,
  // except that it may be specified in both firstprivate and lastprivate
  // clauses.
  for (auto var : op.private_vars()) {
    if (llvm::is_contained(op.firstprivate_vars(), var))
      return op.emitOpError()
             << "operand used in both private and firstprivate clauses";
    if (llvm::is_contained(op.lastprivate_vars(), var))
      return op.emitOpError()
             << "operand used in both private and lastprivate clauses";
  }

  if (op.allocate_vars().size() != op.allocators_vars().size())
    return op.emitError(
        "expected equal sizes for allocate and allocator variables");

  // Only omp.section ops and the terminator may appear directly in the body.
  // The diagnostic must be returned, otherwise verification would carry on and
  // could still report success for malformed IR.
  for (Operation &inst : op.region().front()) {
    if (!isa<SectionOp, TerminatorOp>(inst))
      return op.emitOpError()
             << "expected omp.section op or terminator op inside region";
  }

  return verifyReductionVarList(op, op.reductions(), op.reduction_vars());
}

/// Parses an OpenMP Workshare Loop operation
///
/// wsloop ::= `omp.wsloop` loop-control clause-list
/// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds
/// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
/// steps := `step` `(`ssa-id-list`)`
/// clause-list ::= clause clause-list | empty
/// clause ::= private | firstprivate | lastprivate | linear | schedule |
///            collapse | nowait | ordered | order | reduction
static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {

  // Parse an opening `(` followed by induction variables followed by `)`
  SmallVector<OpAsmParser::OperandType> ivs;
  if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
                                     OpAsmParser::Delimiter::Paren))
    return failure();

  int numIVs = static_cast<int>(ivs.size());
  Type loopVarType;
  if (parser.parseColonType(loopVarType))
    return failure();

  // Parse loop bounds.
  SmallVector<OpAsmParser::OperandType> lower;
  if (parser.parseEqual() ||
      parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
      parser.resolveOperands(lower, loopVarType, result.operands))
    return failure();

  SmallVector<OpAsmParser::OperandType> upper;
  if (parser.parseKeyword("to") ||
      parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
      parser.resolveOperands(upper, loopVarType, result.operands))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("inclusive"))) {
    auto attr = UnitAttr::get(parser.getBuilder().getContext());
    result.addAttribute("inclusive", attr);
  }

  // Parse step values.
  SmallVector<OpAsmParser::OperandType> steps;
  if (parser.parseKeyword("step") ||
      parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
      parser.resolveOperands(steps, loopVarType, result.operands))
    return failure();

  SmallVector<ClauseType> clauses = {
      privateClause,   firstprivateClause, lastprivateClause, linearClause,
      reductionClause, collapseClause,     orderClause,       orderedClause,
      nowaitClause,    scheduleClause};
  // The first three segments (lower bounds, upper bounds, steps) each hold one
  // operand per induction variable; parseClauses appends the rest.
  SmallVector<int> segments{numIVs, numIVs, numIVs};
  if (failed(parseClauses(parser, result, clauses, segments)))
    return failure();

  result.addAttribute("operand_segment_sizes",
                      parser.getBuilder().getI32VectorAttr(segments));

  // Now parse the body. The induction variables become the entry block
  // arguments, all of the loop variable type.
  Region *body = result.addRegion();
  SmallVector<Type> ivTypes(numIVs, loopVarType);
  SmallVector<OpAsmParser::OperandType> blockArgs(ivs);
  if (parser.parseRegion(*body, blockArgs, ivTypes))
    return failure();
  return success();
}

/// Prints a workshare loop: induction variables and their type, bounds,
/// optional `inclusive`, steps, the clauses, then the body. The entry block
/// arguments are already printed in the header, so printRegion skips them.
static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
  auto args = op.getRegion().front().getArguments();
  p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound()
    << ") to (" << op.upperBound() << ") ";
  if (op.inclusive()) {
    p << "inclusive ";
  }
  p << "step (" << op.step() << ") ";

  printDataVars(p, op.private_vars(), "private");
  printDataVars(p, op.firstprivate_vars(), "firstprivate");
  printDataVars(p, op.lastprivate_vars(), "lastprivate");

  if (!op.linear_vars().empty())
    printLinearClause(p, op.linear_vars(), op.linear_step_vars());

  if (auto sched = op.schedule_val())
    printScheduleClause(p, sched.getValue(), op.schedule_modifier(),
                        op.simd_modifier(), op.schedule_chunk_var());

  if (auto collapse = op.collapse_val())
    p << "collapse(" << collapse << ") ";

  if (op.nowait())
    p << "nowait ";

  if (auto ordered = op.ordered_val())
    p << "ordered(" << ordered << ") ";

  if (auto order = op.order_val())
    p << "order(" << stringifyClauseOrderKind(*order) << ") ";

  if (!op.reduction_vars().empty())
    printReductionVarList(p, op.reductions(), op.reduction_vars());

  p << ' ';
  p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
}

//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//

/// Parses the optional atomic reduction region: the keyword `atomic` followed
/// by a region. When the keyword is absent the region is left empty.
static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
                                              Region &region) {
  // parseOptionalKeyword yields failure when the keyword is not present.
  if (parser.parseOptionalKeyword("atomic"))
    return success();
  return parser.parseRegion(region);
}

/// Prints the atomic reduction region (prefixed by `atomic`) only when it is
/// non-empty, mirroring parseAtomicReductionRegion.
static void printAtomicReductionRegion(OpAsmPrinter &printer,
                                       ReductionDeclareOp op, Region &region) {
  if (region.empty())
    return;
  printer << "atomic ";
  printer.printRegion(region);
}

/// Verifies a reduction declaration:
///  - the initializer region takes one argument of the reduction type and
///    yields one value of that type;
///  - the reduction region takes two arguments of the reduction type and
///    yields one value of that type;
///  - the optional atomic reduction region takes two arguments of the same
///    pointer-like type whose element type is the reduction type.
static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) {
  if (op.initializerRegion().empty())
    return op.emitOpError() << "expects non-empty initializer region";
  Block &initializerEntryBlock = op.initializerRegion().front();
  if (initializerEntryBlock.getNumArguments() != 1 ||
      initializerEntryBlock.getArgument(0).getType() != op.type()) {
    return op.emitOpError() << "expects initializer region with one argument "
                               "of the reduction type";
  }

  for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) {
    if (yieldOp.results().size() != 1 ||
        yieldOp.results().getTypes()[0] != op.type())
      return op.emitOpError() << "expects initializer region to yield a value "
                                 "of the reduction type";
  }

  if (op.reductionRegion().empty())
    return op.emitOpError() << "expects non-empty reduction region";
  Block &reductionEntryBlock = op.reductionRegion().front();
  if (reductionEntryBlock.getNumArguments() != 2 ||
      reductionEntryBlock.getArgumentTypes()[0] !=
          reductionEntryBlock.getArgumentTypes()[1] ||
      reductionEntryBlock.getArgumentTypes()[0] != op.type())
    return op.emitOpError() << "expects reduction region with two arguments of "
                               "the reduction type";
  for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) {
    if (yieldOp.results().size() != 1 ||
        yieldOp.results().getTypes()[0] != op.type())
      return op.emitOpError() << "expects reduction region to yield a value "
                                 "of the reduction type";
  }

  // The atomic reduction region is optional.
  if (op.atomicReductionRegion().empty())
    return success();

  Block &atomicReductionEntryBlock = op.atomicReductionRegion().front();
  if (atomicReductionEntryBlock.getNumArguments() != 2 ||
      atomicReductionEntryBlock.getArgumentTypes()[0] !=
          atomicReductionEntryBlock.getArgumentTypes()[1])
    return op.emitOpError() << "expects atomic reduction region with two "
                               "arguments of the same type";
  auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
                     .dyn_cast<PointerLikeType>();
  if (!ptrType || ptrType.getElementType() != op.type())
    return op.emitOpError() << "expects atomic reduction region arguments to "
                               "be accumulators containing the reduction type";
  return success();
}

/// Verifies that the accumulator of an omp.reduction is one of the reduction
/// variables of the closest enclosing omp.wsloop.
static LogicalResult verifyReductionOp(ReductionOp op) {
  // TODO: generalize this to an op interface when there is more than one op
  // that supports reductions.
  auto container = op->getParentOfType<WsLoopOp>();
  // Without an enclosing omp.wsloop no parent can use the accumulator; check
  // for that instead of calling through a null op.
  if (container) {
    for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
      if (container.reduction_vars()[i] == op.accumulator())
        return success();
  }

  return op.emitOpError() << "the accumulator is not used by the parent";
}

//===----------------------------------------------------------------------===//
// WsLoopOp
//===----------------------------------------------------------------------===//

/// Builds a workshare loop with only bounds and steps; all clauses are left
/// unset and no body block is created. Extra `attributes` are appended.
void WsLoopOp::build(OpBuilder &builder, OperationState &state,
                     ValueRange lowerBound, ValueRange upperBound,
                     ValueRange step, ArrayRef<NamedAttribute> attributes) {
  build(builder, state, TypeRange(), lowerBound, upperBound, step,
        /*privateVars=*/ValueRange(),
        /*firstprivateVars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
        /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
        /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr,
        /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr,
        /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr,
        /*inclusive=*/nullptr, /*buildBody=*/false);
  state.addAttributes(attributes);
}

/// Generic builder: adds the given operands and attributes verbatim plus one
/// empty region. The op has no results, so `resultTypes` must be empty.
void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes,
                     ValueRange operands, ArrayRef<NamedAttribute> attributes) {
  state.addOperands(operands);
  state.addAttributes(attributes);
  (void)state.addRegion();
  assert(resultTypes.empty() && "mismatched number of return types");
  state.addTypes(resultTypes);
}

/// Full builder for a workshare loop. Operands are appended group by group in
/// the same order as the sizes stored in the operand segment attribute. When
/// `buildBody` is set, an entry block is created with one argument per
/// induction variable, typed like the first step value.
void WsLoopOp::build(OpBuilder &builder, OperationState &result,
                     TypeRange typeRange, ValueRange lowerBounds,
                     ValueRange upperBounds, ValueRange steps,
                     ValueRange privateVars, ValueRange firstprivateVars,
                     ValueRange lastprivateVars, ValueRange linearVars,
                     ValueRange linearStepVars, ValueRange reductionVars,
                     StringAttr scheduleVal, Value scheduleChunkVar,
                     IntegerAttr collapseVal, UnitAttr nowait,
                     IntegerAttr orderedVal, StringAttr orderVal,
                     UnitAttr inclusive, bool buildBody) {
  // Every group whose size is recorded below must be added here, in the same
  // order; lastprivate and reduction operands were previously dropped, which
  // left the segment sizes inconsistent with the operand list.
  result.addOperands(lowerBounds);
  result.addOperands(upperBounds);
  result.addOperands(steps);
  result.addOperands(privateVars);
  result.addOperands(firstprivateVars);
  result.addOperands(lastprivateVars);
  result.addOperands(linearVars);
  result.addOperands(linearStepVars);
  result.addOperands(reductionVars);
  if (scheduleChunkVar)
    result.addOperands(scheduleChunkVar);

  if (scheduleVal)
    result.addAttribute("schedule_val", scheduleVal);
  if (collapseVal)
    result.addAttribute("collapse_val", collapseVal);
  if (nowait)
    result.addAttribute("nowait", nowait);
  if (orderedVal)
    result.addAttribute("ordered_val", orderedVal);
  // Use the attribute name read back by the `order_val()` accessor; under any
  // other key the order clause would be silently ignored.
  if (orderVal)
    result.addAttribute("order_val", orderVal);
  if (inclusive)
    result.addAttribute("inclusive", inclusive);
  result.addAttribute(
      WsLoopOp::getOperandSegmentSizeAttr(),
      builder.getI32VectorAttr(
          {static_cast<int32_t>(lowerBounds.size()),
           static_cast<int32_t>(upperBounds.size()),
           static_cast<int32_t>(steps.size()),
           static_cast<int32_t>(privateVars.size()),
           static_cast<int32_t>(firstprivateVars.size()),
           static_cast<int32_t>(lastprivateVars.size()),
           static_cast<int32_t>(linearVars.size()),
           static_cast<int32_t>(linearStepVars.size()),
           static_cast<int32_t>(reductionVars.size()),
           static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)}));

  Region *bodyRegion = result.addRegion();
  if (buildBody) {
    OpBuilder::InsertionGuard guard(builder);
    unsigned numIVs = steps.size();
    SmallVector<Type, 8> argTypes(numIVs, steps.getType().front());
    builder.createBlock(bodyRegion, {}, argTypes);
  }
}

/// Verifies the reduction clause of a workshare loop.
static LogicalResult verifyWsLoopOp(WsLoopOp op) {
  return verifyReductionVarList(op, op.reductions(), op.reduction_vars());
}

//===----------------------------------------------------------------------===//
// Verifier for critical construct (2.17.1)
//===----------------------------------------------------------------------===//

/// Verifies the synchronization hint of a critical declaration.
static LogicalResult verifyCriticalDeclareOp(CriticalDeclareOp op) {
  return verifySynchronizationHint(op, op.hint());
}

/// Verifies that a named critical construct refers to a critical declaration
/// reachable from the op.
static LogicalResult verifyCriticalOp(CriticalOp op) {

  if (op.nameAttr()) {
    auto symbolRef = op.nameAttr().cast<SymbolRefAttr>();
    auto decl =
        SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef);
    if (!decl) {
      return op.emitOpError() << "expected symbol reference " << symbolRef
                              << " to point to a critical declaration";
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// Verifier for ordered construct
//===----------------------------------------------------------------------===//

/// Verifies an ordered depend directive: it must be nested in a
/// worksharing-loop carrying a non-zero `ordered` parameter, and that
/// parameter must equal the number of loops named by the depend clause.
static LogicalResult verifyOrderedOp(OrderedOp op) {
  auto container = op->getParentOfType<WsLoopOp>();
  if (!container || !container.ordered_valAttr() ||
      container.ordered_valAttr().getInt() == 0)
    return op.emitOpError() << "ordered depend directive must be closely "
                            << "nested inside a worksharing-loop with ordered "
                            << "clause with parameter present";

  if (container.ordered_valAttr().getInt() !=
      static_cast<int64_t>(op.num_loops_val().getValue()))
    return op.emitOpError() << "number of variables in depend clause does not "
                            << "match number of iteration variables in the "
                            << "doacross loop";

  return success();
}

/// Verifies an ordered region: when nested in a worksharing-loop, that loop
/// must carry an `ordered` clause without parameter (value 0).
static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) {
  // TODO: The code generation for ordered simd directive is not supported yet.
  // Emit a diagnostic so the failure is not silent.
  if (op.simd())
    return op.emitOpError() << "ordered simd regions are not supported yet";

  if (auto container = op->getParentOfType<WsLoopOp>()) {
    if (!container.ordered_valAttr() ||
        container.ordered_valAttr().getInt() != 0)
      return op.emitOpError() << "ordered region must be closely nested inside "
                              << "a worksharing-loop region with an ordered "
                              << "clause without parameter present";
  }

  return success();
}

//===----------------------------------------------------------------------===//
// AtomicReadOp
//===----------------------------------------------------------------------===//

/// Parser for AtomicReadOp
///
/// operation ::= `omp.atomic.read` v `=` x atomic-clause-list `:` type
/// atomic-clause-list ::= (memory_order | hint)*
///
/// Both `v` and `x` are resolved with the single trailing type.
static ParseResult parseAtomicReadOp(OpAsmParser &parser,
                                     OperationState &result) {
  OpAsmParser::OperandType x, v;
  Type addressType;
  SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
  SmallVector<int> segments;

  // Operands are registered in the order x, v even though v is written first.
  if (parser.parseOperand(v) || parser.parseEqual() || parser.parseOperand(x) ||
      parseClauses(parser, result, clauses, segments) ||
      parser.parseColonType(addressType) ||
      parser.resolveOperand(x, addressType, result.operands) ||
      parser.resolveOperand(v, addressType, result.operands))
    return failure();

  return success();
}

/// Printer for AtomicReadOp
static void printAtomicReadOp(OpAsmPrinter &p, AtomicReadOp op) {
  p << " " << op.v() << " = " << op.x() << " ";
  if (auto mo = op.memory_order())
    p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
  // The preceding output already ends in a space; print the hint the same way
  // the atomic write/update printers do.
  if (op.hintAttr())
    printSynchronizationHint(p, op, op.hintAttr());
  p << ": " << op.x().getType();
}

/// Verifier for AtomicReadOp
static LogicalResult verifyAtomicReadOp(AtomicReadOp op) {
  if (auto mo = op.memory_order()) {
    if (*mo == ClauseMemoryOrderKind::acq_rel ||
        *mo == ClauseMemoryOrderKind::release) {
      return op.emitError(
          "memory-order must not be acq_rel or release for atomic reads");
    }
  }
  if (op.x() == op.v())
    return op.emitError(
        "read and write must not be to the same location for atomic reads");
  return verifySynchronizationHint(op, op.hint());
}

//===----------------------------------------------------------------------===//
// AtomicWriteOp
//===----------------------------------------------------------------------===//

/// Parser for AtomicWriteOp
///
/// operation ::= `omp.atomic.write` address `=` value atomic-clause-list
///               `:` address-type `,` value-type
/// atomic-clause-list ::= (memory_order | hint)*
static ParseResult parseAtomicWriteOp(OpAsmParser &parser,
                                      OperationState &result) {
  OpAsmParser::OperandType address, value;
  Type addrType, valueType;
  SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
  SmallVector<int> segments;

  if (parser.parseOperand(address) || parser.parseEqual() ||
      parser.parseOperand(value) ||
      parseClauses(parser, result, clauses, segments) ||
      parser.parseColonType(addrType) || parser.parseComma() ||
      parser.parseType(valueType) ||
      parser.resolveOperand(address, addrType, result.operands) ||
      parser.resolveOperand(value, valueType, result.operands))
    return failure();
  return success();
}

/// Printer for AtomicWriteOp
static void printAtomicWriteOp(OpAsmPrinter &p, AtomicWriteOp op) {
  p << " " << op.address() << " = " << op.value() << " ";
  if (auto mo = op.memory_order())
    p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
  if (op.hintAttr())
    printSynchronizationHint(p, op, op.hintAttr());
  p << ": " << op.address().getType() << ", " << op.value().getType();
}

/// Verifier for AtomicWriteOp
static LogicalResult verifyAtomicWriteOp(AtomicWriteOp op) {
  if (auto mo = op.memory_order()) {
    if (*mo == ClauseMemoryOrderKind::acq_rel ||
        *mo == ClauseMemoryOrderKind::acquire) {
      return op.emitError(
          "memory-order must not be acq_rel or acquire for atomic writes");
    }
  }
  return verifySynchronizationHint(op, op.hint());
}

//===----------------------------------------------------------------------===//
// AtomicUpdateOp
//===----------------------------------------------------------------------===//

/// Parser for AtomicUpdateOp
///
/// operation ::= `omp.atomic.update` x `=` lhs binop rhs atomic-clause-list
///               `:` x-type `,` expr-type
/// atomic-clause-list ::= (memory_order | hint)*
///
/// `x` must appear as `lhs` or `rhs`; the other side is the update expression.
/// When `x` is the left-hand side the `isXBinopExpr` unit attribute is set.
static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
                                       OperationState &result) {
  SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
  SmallVector<int> segments;
  OpAsmParser::OperandType x, y, z;
  Type xType, exprType;
  StringRef binOp;

  // x = y `op` z : xtype, exprtype
  if (parser.parseOperand(x) || parser.parseEqual() || parser.parseOperand(y) ||
      parser.parseKeyword(&binOp) || parser.parseOperand(z) ||
      parseClauses(parser, result, clauses, segments) || parser.parseColon() ||
      parser.parseType(xType) || parser.parseComma() ||
      parser.parseType(exprType) ||
      parser.resolveOperand(x, xType, result.operands)) {
    return failure();
  }

  auto binOpEnum = AtomicBinOpKindToEnum(binOp.upper());
  if (!binOpEnum)
    return parser.emitError(parser.getNameLoc())
           << "invalid atomic bin op in atomic update\n";
  auto attr =
      parser.getBuilder().getI64IntegerAttr((int64_t)binOpEnum.getValue());
  result.addAttribute("binop", attr);

  // Decide which side of the binary operation is the expression operand.
  OpAsmParser::OperandType expr;
  if (x.name == y.name && x.number == y.number) {
    expr = z;
    result.addAttribute("isXBinopExpr", parser.getBuilder().getUnitAttr());
  } else if (x.name == z.name && x.number == z.number) {
    expr = y;
  } else {
    return parser.emitError(parser.getNameLoc())
           << "atomic update variable " << x.name
           << " not found in the RHS of the assignment statement in an"
              " atomic.update operation";
  }
  return parser.resolveOperand(expr, exprType, result.operands);
}

/// Printer for AtomicUpdateOp
static void printAtomicUpdateOp(OpAsmPrinter &p, AtomicUpdateOp op) {
  p << " " << op.x() << " = ";
  // Put `x` back on the side of the binary operation it was parsed from.
  Value y, z;
  if (op.isXBinopExpr()) {
    y = op.x();
    z = op.expr();
  } else {
    y = op.expr();
    z = op.x();
  }
  p << y << " " << AtomicBinOpKindToString(op.binop()).lower() << " " << z
    << " ";
  if (auto mo = op.memory_order())
    p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
  if (op.hintAttr())
    printSynchronizationHint(p, op, op.hintAttr());
  p << ": " << op.x().getType() << ", " << op.expr().getType();
}

/// Verifier for AtomicUpdateOp
static LogicalResult verifyAtomicUpdateOp(AtomicUpdateOp op) {
  if (auto mo = op.memory_order()) {
    if (*mo == ClauseMemoryOrderKind::acq_rel ||
        *mo == ClauseMemoryOrderKind::acquire) {
      return op.emitError(
          "memory-order must not be acq_rel or acquire for atomic updates");
    }
  }
  // Validate the hint clause like the atomic read/write verifiers do.
  return verifySynchronizationHint(op, op.hint());
}

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"