1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the OpenMP dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/IR/Attributes.h" 17 #include "mlir/IR/DialectImplementation.h" 18 #include "mlir/IR/OpImplementation.h" 19 #include "mlir/IR/OperationSupport.h" 20 21 #include "llvm/ADT/BitVector.h" 22 #include "llvm/ADT/SmallString.h" 23 #include "llvm/ADT/StringExtras.h" 24 #include "llvm/ADT/StringRef.h" 25 #include "llvm/ADT/StringSwitch.h" 26 #include "llvm/ADT/TypeSwitch.h" 27 #include <cstddef> 28 29 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc" 30 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" 31 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc" 32 33 using namespace mlir; 34 using namespace mlir::omp; 35 36 namespace { 37 /// Model for pointer-like types that already provide a `getElementType` method. 38 template <typename T> 39 struct PointerLikeModel 40 : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> { 41 Type getElementType(Type pointer) const { 42 return pointer.cast<T>().getElementType(); 43 } 44 }; 45 } // namespace 46 47 void OpenMPDialect::initialize() { 48 addOperations< 49 #define GET_OP_LIST 50 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 51 >(); 52 addAttributes< 53 #define GET_ATTRDEF_LIST 54 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" 55 >(); 56 57 LLVM::LLVMPointerType::attachInterface< 58 PointerLikeModel<LLVM::LLVMPointerType>>(*getContext()); 59 MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext()); 60 } 61 62 //===----------------------------------------------------------------------===// 63 // ParallelOp 64 //===----------------------------------------------------------------------===// 65 66 void ParallelOp::build(OpBuilder &builder, OperationState &state, 67 ArrayRef<NamedAttribute> attributes) { 68 ParallelOp::build( 69 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, 70 /*default_val=*/nullptr, /*private_vars=*/ValueRange(), 71 /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(), 72 /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(), 73 /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr); 74 state.addAttributes(attributes); 75 } 76 77 //===----------------------------------------------------------------------===// 78 // Parser and printer for Operand and type list 79 //===----------------------------------------------------------------------===// 80 81 /// Parse a list of operands with types. 82 /// 83 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)` 84 /// ssa-id-and-type-list ::= ssa-id-and-type | 85 /// ssa-id-and-type `,` ssa-id-and-type-list 86 /// ssa-id-and-type ::= ssa-id `:` type 87 static ParseResult 88 parseOperandAndTypeList(OpAsmParser &parser, 89 SmallVectorImpl<OpAsmParser::OperandType> &operands, 90 SmallVectorImpl<Type> &types) { 91 return parser.parseCommaSeparatedList( 92 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 93 OpAsmParser::OperandType operand; 94 Type type; 95 if (parser.parseOperand(operand) || parser.parseColonType(type)) 96 return failure(); 97 operands.push_back(operand); 98 types.push_back(type); 99 return success(); 100 }); 101 } 102 103 /// Print an operand and type list with parentheses 104 static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) { 105 p << "("; 106 llvm::interleaveComma( 107 operands, p, [&](const Value &v) { p << v << " : " << v.getType(); }); 108 p << ") "; 109 } 110 111 /// Print data variables corresponding to a data-sharing clause `name` 112 static void printDataVars(OpAsmPrinter &p, OperandRange operands, 113 StringRef name) { 114 if (!operands.empty()) { 115 p << name; 116 printOperandAndTypeList(p, operands); 117 } 118 } 119 120 //===----------------------------------------------------------------------===// 121 // Parser and printer for Allocate Clause 122 //===----------------------------------------------------------------------===// 123 124 /// Parse an allocate clause with allocators and a list of operands with types. 125 /// 126 /// allocate ::= `allocate` `(` allocate-operand-list `)` 127 /// allocate-operand-list :: = allocate-operand | 128 /// allocator-operand `,` allocate-operand-list 129 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 130 /// ssa-id-and-type ::= ssa-id `:` type 131 static ParseResult parseAllocateAndAllocator( 132 OpAsmParser &parser, 133 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate, 134 SmallVectorImpl<Type> &typesAllocate, 135 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator, 136 SmallVectorImpl<Type> &typesAllocator) { 137 138 return parser.parseCommaSeparatedList( 139 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 140 OpAsmParser::OperandType operand; 141 Type type; 142 if (parser.parseOperand(operand) || parser.parseColonType(type)) 143 return failure(); 144 operandsAllocator.push_back(operand); 145 typesAllocator.push_back(type); 146 if (parser.parseArrow()) 147 return failure(); 148 if (parser.parseOperand(operand) || parser.parseColonType(type)) 149 return failure(); 150 151 operandsAllocate.push_back(operand); 152 typesAllocate.push_back(type); 153 return success(); 154 }); 155 } 156 157 /// Print allocate clause 158 static void printAllocateAndAllocator(OpAsmPrinter &p, 159 OperandRange varsAllocate, 160 OperandRange varsAllocator) { 161 p << "allocate("; 162 for (unsigned i = 0; i < varsAllocate.size(); ++i) { 163 std::string separator = i == varsAllocate.size() - 1 ? ") " : ", "; 164 p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> "; 165 p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator; 166 } 167 } 168 169 static LogicalResult verifyParallelOp(ParallelOp op) { 170 if (op.allocate_vars().size() != op.allocators_vars().size()) 171 return op.emitError( 172 "expected equal sizes for allocate and allocator variables"); 173 return success(); 174 } 175 176 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) { 177 p << " "; 178 if (auto ifCond = op.if_expr_var()) 179 p << "if(" << ifCond << " : " << ifCond.getType() << ") "; 180 181 if (auto threads = op.num_threads_var()) 182 p << "num_threads(" << threads << " : " << threads.getType() << ") "; 183 184 printDataVars(p, op.private_vars(), "private"); 185 printDataVars(p, op.firstprivate_vars(), "firstprivate"); 186 printDataVars(p, op.shared_vars(), "shared"); 187 printDataVars(p, op.copyin_vars(), "copyin"); 188 189 if (!op.allocate_vars().empty()) 190 printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars()); 191 192 if (auto def = op.default_val()) 193 p << "default(" << stringifyClauseDefault(*def).drop_front(3) << ") "; 194 195 if (auto bind = op.proc_bind_val()) 196 p << "proc_bind(" << stringifyClauseProcBindKind(*bind) << ") "; 197 198 p << ' '; 199 p.printRegion(op.getRegion()); 200 } 201 202 //===----------------------------------------------------------------------===// 203 // Parser and printer for Linear Clause 204 //===----------------------------------------------------------------------===// 205 206 /// linear ::= `linear` `(` linear-list `)` 207 /// linear-list := linear-val | linear-val linear-list 208 /// linear-val := ssa-id-and-type `=` ssa-id-and-type 209 static ParseResult 210 parseLinearClause(OpAsmParser &parser, 211 SmallVectorImpl<OpAsmParser::OperandType> &vars, 212 SmallVectorImpl<Type> &types, 213 SmallVectorImpl<OpAsmParser::OperandType> &stepVars) { 214 if (parser.parseLParen()) 215 return failure(); 216 217 do { 218 OpAsmParser::OperandType var; 219 Type type; 220 OpAsmParser::OperandType stepVar; 221 if (parser.parseOperand(var) || parser.parseEqual() || 222 parser.parseOperand(stepVar) || parser.parseColonType(type)) 223 return failure(); 224 225 vars.push_back(var); 226 types.push_back(type); 227 stepVars.push_back(stepVar); 228 } while (succeeded(parser.parseOptionalComma())); 229 230 if (parser.parseRParen()) 231 return failure(); 232 233 return success(); 234 } 235 236 /// Print Linear Clause 237 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars, 238 OperandRange linearStepVars) { 239 size_t linearVarsSize = linearVars.size(); 240 p << "linear("; 241 for (unsigned i = 0; i < linearVarsSize; ++i) { 242 std::string separator = i == linearVarsSize - 1 ? ") " : ", "; 243 p << linearVars[i]; 244 if (linearStepVars.size() > i) 245 p << " = " << linearStepVars[i]; 246 p << " : " << linearVars[i].getType() << separator; 247 } 248 } 249 250 //===----------------------------------------------------------------------===// 251 // Parser and printer for Schedule Clause 252 //===----------------------------------------------------------------------===// 253 254 static ParseResult 255 verifyScheduleModifiers(OpAsmParser &parser, 256 SmallVectorImpl<SmallString<12>> &modifiers) { 257 if (modifiers.size() > 2) 258 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)"; 259 for (const auto &mod : modifiers) { 260 // Translate the string. If it has no value, then it was not a valid 261 // modifier! 262 auto symbol = symbolizeScheduleModifier(mod); 263 if (!symbol.hasValue()) 264 return parser.emitError(parser.getNameLoc()) 265 << " unknown modifier type: " << mod; 266 } 267 268 // If we have one modifier that is "simd", then stick a "none" modiifer in 269 // index 0. 270 if (modifiers.size() == 1) { 271 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) { 272 modifiers.push_back(modifiers[0]); 273 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none); 274 } 275 } else if (modifiers.size() == 2) { 276 // If there are two modifier: 277 // First modifier should not be simd, second one should be simd 278 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd || 279 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd) 280 return parser.emitError(parser.getNameLoc()) 281 << " incorrect modifier order"; 282 } 283 return success(); 284 } 285 286 /// schedule ::= `schedule` `(` sched-list `)` 287 /// sched-list ::= sched-val | sched-val sched-list | 288 /// sched-val `,` sched-modifier 289 /// sched-val ::= sched-with-chunk | sched-wo-chunk 290 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 291 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided` 292 /// sched-wo-chunk ::= `auto` | `runtime` 293 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val 294 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none` 295 static ParseResult 296 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, 297 SmallVectorImpl<SmallString<12>> &modifiers, 298 Optional<OpAsmParser::OperandType> &chunkSize, 299 Type &chunkType) { 300 if (parser.parseLParen()) 301 return failure(); 302 303 StringRef keyword; 304 if (parser.parseKeyword(&keyword)) 305 return failure(); 306 307 schedule = keyword; 308 if (keyword == "static" || keyword == "dynamic" || keyword == "guided") { 309 if (succeeded(parser.parseOptionalEqual())) { 310 chunkSize = OpAsmParser::OperandType{}; 311 if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType)) 312 return failure(); 313 } else { 314 chunkSize = llvm::NoneType::None; 315 } 316 } else if (keyword == "auto" || keyword == "runtime") { 317 chunkSize = llvm::NoneType::None; 318 } else { 319 return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; 320 } 321 322 // If there is a comma, we have one or more modifiers.. 323 while (succeeded(parser.parseOptionalComma())) { 324 StringRef mod; 325 if (parser.parseKeyword(&mod)) 326 return failure(); 327 modifiers.push_back(mod); 328 } 329 330 if (parser.parseRParen()) 331 return failure(); 332 333 if (verifyScheduleModifiers(parser, modifiers)) 334 return failure(); 335 336 return success(); 337 } 338 339 /// Print schedule clause 340 static void printScheduleClause(OpAsmPrinter &p, ClauseScheduleKind sched, 341 Optional<ScheduleModifier> modifier, bool simd, 342 Value scheduleChunkVar) { 343 p << "schedule(" << stringifyClauseScheduleKind(sched).lower(); 344 if (scheduleChunkVar) 345 p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType(); 346 if (modifier) 347 p << ", " << stringifyScheduleModifier(*modifier); 348 if (simd) 349 p << ", simd"; 350 p << ") "; 351 } 352 353 //===----------------------------------------------------------------------===// 354 // Parser, printer and verifier for ReductionVarList 355 //===----------------------------------------------------------------------===// 356 357 /// reduction ::= `reduction` `(` reduction-entry-list `)` 358 /// reduction-entry-list ::= reduction-entry 359 /// | reduction-entry-list `,` reduction-entry 360 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type 361 static ParseResult 362 parseReductionVarList(OpAsmParser &parser, 363 SmallVectorImpl<SymbolRefAttr> &symbols, 364 SmallVectorImpl<OpAsmParser::OperandType> &operands, 365 SmallVectorImpl<Type> &types) { 366 if (failed(parser.parseLParen())) 367 return failure(); 368 369 do { 370 if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() || 371 parser.parseOperand(operands.emplace_back()) || 372 parser.parseColonType(types.emplace_back())) 373 return failure(); 374 } while (succeeded(parser.parseOptionalComma())); 375 return parser.parseRParen(); 376 } 377 378 /// Print Reduction clause 379 static void printReductionVarList(OpAsmPrinter &p, 380 Optional<ArrayAttr> reductions, 381 OperandRange reductionVars) { 382 p << "reduction("; 383 for (unsigned i = 0, e = reductions->size(); i < e; ++i) { 384 if (i != 0) 385 p << ", "; 386 p << (*reductions)[i] << " -> " << reductionVars[i] << " : " 387 << reductionVars[i].getType(); 388 } 389 p << ") "; 390 } 391 392 /// Verifies Reduction Clause 393 static LogicalResult verifyReductionVarList(Operation *op, 394 Optional<ArrayAttr> reductions, 395 OperandRange reductionVars) { 396 if (!reductionVars.empty()) { 397 if (!reductions || reductions->size() != reductionVars.size()) 398 return op->emitOpError() 399 << "expected as many reduction symbol references " 400 "as reduction variables"; 401 } else { 402 if (reductions) 403 return op->emitOpError() << "unexpected reduction symbol references"; 404 return success(); 405 } 406 407 DenseSet<Value> accumulators; 408 for (auto args : llvm::zip(reductionVars, *reductions)) { 409 Value accum = std::get<0>(args); 410 411 if (!accumulators.insert(accum).second) 412 return op->emitOpError() << "accumulator variable used more than once"; 413 414 Type varType = accum.getType().cast<PointerLikeType>(); 415 auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>(); 416 auto decl = 417 SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef); 418 if (!decl) 419 return op->emitOpError() << "expected symbol reference " << symbolRef 420 << " to point to a reduction declaration"; 421 422 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) 423 return op->emitOpError() 424 << "expected accumulator (" << varType 425 << ") to be the same type as reduction declaration (" 426 << decl.getAccumulatorType() << ")"; 427 } 428 429 return success(); 430 } 431 432 //===----------------------------------------------------------------------===// 433 // Parser, printer and verifier for Synchronization Hint (2.17.12) 434 //===----------------------------------------------------------------------===// 435 436 /// Parses a Synchronization Hint clause. The value of hint is an integer 437 /// which is a combination of different hints from `omp_sync_hint_t`. 438 /// 439 /// hint-clause = `hint` `(` hint-value `)` 440 static ParseResult parseSynchronizationHint(OpAsmParser &parser, 441 IntegerAttr &hintAttr, 442 bool parseKeyword = true) { 443 if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) { 444 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0); 445 return success(); 446 } 447 448 if (failed(parser.parseLParen())) 449 return failure(); 450 StringRef hintKeyword; 451 int64_t hint = 0; 452 do { 453 if (failed(parser.parseKeyword(&hintKeyword))) 454 return failure(); 455 if (hintKeyword == "uncontended") 456 hint |= 1; 457 else if (hintKeyword == "contended") 458 hint |= 2; 459 else if (hintKeyword == "nonspeculative") 460 hint |= 4; 461 else if (hintKeyword == "speculative") 462 hint |= 8; 463 else 464 return parser.emitError(parser.getCurrentLocation()) 465 << hintKeyword << " is not a valid hint"; 466 } while (succeeded(parser.parseOptionalComma())); 467 if (failed(parser.parseRParen())) 468 return failure(); 469 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint); 470 return success(); 471 } 472 473 /// Prints a Synchronization Hint clause 474 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, 475 IntegerAttr hintAttr) { 476 int64_t hint = hintAttr.getInt(); 477 478 if (hint == 0) 479 return; 480 481 // Helper function to get n-th bit from the right end of `value` 482 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 483 484 bool uncontended = bitn(hint, 0); 485 bool contended = bitn(hint, 1); 486 bool nonspeculative = bitn(hint, 2); 487 bool speculative = bitn(hint, 3); 488 489 SmallVector<StringRef> hints; 490 if (uncontended) 491 hints.push_back("uncontended"); 492 if (contended) 493 hints.push_back("contended"); 494 if (nonspeculative) 495 hints.push_back("nonspeculative"); 496 if (speculative) 497 hints.push_back("speculative"); 498 499 p << "hint("; 500 llvm::interleaveComma(hints, p); 501 p << ") "; 502 } 503 504 /// Verifies a synchronization hint clause 505 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) { 506 507 // Helper function to get n-th bit from the right end of `value` 508 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 509 510 bool uncontended = bitn(hint, 0); 511 bool contended = bitn(hint, 1); 512 bool nonspeculative = bitn(hint, 2); 513 bool speculative = bitn(hint, 3); 514 515 if (uncontended && contended) 516 return op->emitOpError() << "the hints omp_sync_hint_uncontended and " 517 "omp_sync_hint_contended cannot be combined"; 518 if (nonspeculative && speculative) 519 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and " 520 "omp_sync_hint_speculative cannot be combined."; 521 return success(); 522 } 523 524 enum ClauseType { 525 ifClause, 526 numThreadsClause, 527 privateClause, 528 firstprivateClause, 529 lastprivateClause, 530 sharedClause, 531 copyinClause, 532 allocateClause, 533 defaultClause, 534 procBindClause, 535 reductionClause, 536 nowaitClause, 537 linearClause, 538 scheduleClause, 539 collapseClause, 540 orderClause, 541 orderedClause, 542 memoryOrderClause, 543 hintClause, 544 COUNT 545 }; 546 547 //===----------------------------------------------------------------------===// 548 // Parser for Clause List 549 //===----------------------------------------------------------------------===// 550 551 /// Parse a clause attribute `(` $value `)`. 552 template <typename ClauseAttr> 553 static ParseResult parseClauseAttr(AsmParser &parser, OperationState &state, 554 StringRef attrName, StringRef name) { 555 using ClauseT = decltype(std::declval<ClauseAttr>().getValue()); 556 StringRef enumStr; 557 llvm::SMLoc loc = parser.getCurrentLocation(); 558 if (parser.parseLParen() || parser.parseKeyword(&enumStr) || 559 parser.parseRParen()) 560 return failure(); 561 if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) { 562 auto attr = ClauseAttr::get(parser.getContext(), *enumValue); 563 state.addAttribute(attrName, attr); 564 return success(); 565 } 566 return parser.emitError(loc, "invalid ") << name << " kind"; 567 } 568 569 /// Parse a list of clauses. The clauses can appear in any order, but their 570 /// operand segment indices are in the same order that they are passed in the 571 /// `clauses` list. The operand segments are added over the prevSegments 572 573 /// clause-list ::= clause clause-list | empty 574 /// clause ::= if | num-threads | private | firstprivate | lastprivate | 575 /// shared | copyin | allocate | default | proc-bind | reduction | 576 /// nowait | linear | schedule | collapse | order | ordered | 577 /// inclusive 578 /// if ::= `if` `(` ssa-id-and-type `)` 579 /// num-threads ::= `num_threads` `(` ssa-id-and-type `)` 580 /// private ::= `private` operand-and-type-list 581 /// firstprivate ::= `firstprivate` operand-and-type-list 582 /// lastprivate ::= `lastprivate` operand-and-type-list 583 /// shared ::= `shared` operand-and-type-list 584 /// copyin ::= `copyin` operand-and-type-list 585 /// allocate ::= `allocate` `(` allocate-operand-list `)` 586 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) 587 /// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` 588 /// reduction ::= `reduction` `(` reduction-entry-list `)` 589 /// nowait ::= `nowait` 590 /// linear ::= `linear` `(` linear-list `)` 591 /// schedule ::= `schedule` `(` sched-list `)` 592 /// collapse ::= `collapse` `(` ssa-id-and-type `)` 593 /// order ::= `order` `(` `concurrent` `)` 594 /// ordered ::= `ordered` `(` ssa-id-and-type `)` 595 /// inclusive ::= `inclusive` 596 /// 597 /// Note that each clause can only appear once in the clase-list. 598 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result, 599 SmallVectorImpl<ClauseType> &clauses, 600 SmallVectorImpl<int> &segments) { 601 602 // Check done[clause] to see if it has been parsed already 603 llvm::BitVector done(ClauseType::COUNT, false); 604 605 // See pos[clause] to get position of clause in operand segments 606 SmallVector<int> pos(ClauseType::COUNT, -1); 607 608 // Stores the last parsed clause keyword 609 StringRef clauseKeyword; 610 StringRef opName = result.name.getStringRef(); 611 612 // Containers for storing operands, types and attributes for various clauses 613 std::pair<OpAsmParser::OperandType, Type> ifCond; 614 std::pair<OpAsmParser::OperandType, Type> numThreads; 615 616 SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates, 617 shareds, copyins; 618 SmallVector<Type> privateTypes, firstprivateTypes, lastprivateTypes, 619 sharedTypes, copyinTypes; 620 621 SmallVector<OpAsmParser::OperandType> allocates, allocators; 622 SmallVector<Type> allocateTypes, allocatorTypes; 623 624 SmallVector<SymbolRefAttr> reductionSymbols; 625 SmallVector<OpAsmParser::OperandType> reductionVars; 626 SmallVector<Type> reductionVarTypes; 627 628 SmallVector<OpAsmParser::OperandType> linears; 629 SmallVector<Type> linearTypes; 630 SmallVector<OpAsmParser::OperandType> linearSteps; 631 632 SmallString<8> schedule; 633 SmallVector<SmallString<12>> modifiers; 634 Optional<OpAsmParser::OperandType> scheduleChunkSize; 635 Type scheduleChunkType; 636 637 // Compute the position of clauses in operand segments 638 int currPos = 0; 639 for (ClauseType clause : clauses) { 640 641 // Skip the following clauses - they do not take any position in operand 642 // segments 643 if (clause == defaultClause || clause == procBindClause || 644 clause == nowaitClause || clause == collapseClause || 645 clause == orderClause || clause == orderedClause) 646 continue; 647 648 pos[clause] = currPos++; 649 650 // For the following clauses, two positions are reserved in the operand 651 // segments 652 if (clause == allocateClause || clause == linearClause) 653 currPos++; 654 } 655 656 SmallVector<int> clauseSegments(currPos); 657 658 // Helper function to check if a clause is allowed/repeated or not 659 auto checkAllowed = [&](ClauseType clause) -> ParseResult { 660 if (!llvm::is_contained(clauses, clause)) 661 return parser.emitError(parser.getCurrentLocation()) 662 << clauseKeyword << " is not a valid clause for the " << opName 663 << " operation"; 664 if (done[clause]) 665 return parser.emitError(parser.getCurrentLocation()) 666 << "at most one " << clauseKeyword << " clause can appear on the " 667 << opName << " operation"; 668 done[clause] = true; 669 return success(); 670 }; 671 672 while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) { 673 if (clauseKeyword == "if") { 674 if (checkAllowed(ifClause) || parser.parseLParen() || 675 parser.parseOperand(ifCond.first) || 676 parser.parseColonType(ifCond.second) || parser.parseRParen()) 677 return failure(); 678 clauseSegments[pos[ifClause]] = 1; 679 } else if (clauseKeyword == "num_threads") { 680 if (checkAllowed(numThreadsClause) || parser.parseLParen() || 681 parser.parseOperand(numThreads.first) || 682 parser.parseColonType(numThreads.second) || parser.parseRParen()) 683 return failure(); 684 clauseSegments[pos[numThreadsClause]] = 1; 685 } else if (clauseKeyword == "private") { 686 if (checkAllowed(privateClause) || 687 parseOperandAndTypeList(parser, privates, privateTypes)) 688 return failure(); 689 clauseSegments[pos[privateClause]] = privates.size(); 690 } else if (clauseKeyword == "firstprivate") { 691 if (checkAllowed(firstprivateClause) || 692 parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) 693 return failure(); 694 clauseSegments[pos[firstprivateClause]] = firstprivates.size(); 695 } else if (clauseKeyword == "lastprivate") { 696 if (checkAllowed(lastprivateClause) || 697 parseOperandAndTypeList(parser, lastprivates, lastprivateTypes)) 698 return failure(); 699 clauseSegments[pos[lastprivateClause]] = lastprivates.size(); 700 } else if (clauseKeyword == "shared") { 701 if (checkAllowed(sharedClause) || 702 parseOperandAndTypeList(parser, shareds, sharedTypes)) 703 return failure(); 704 clauseSegments[pos[sharedClause]] = shareds.size(); 705 } else if (clauseKeyword == "copyin") { 706 if (checkAllowed(copyinClause) || 707 parseOperandAndTypeList(parser, copyins, copyinTypes)) 708 return failure(); 709 clauseSegments[pos[copyinClause]] = copyins.size(); 710 } else if (clauseKeyword == "allocate") { 711 if (checkAllowed(allocateClause) || 712 parseAllocateAndAllocator(parser, allocates, allocateTypes, 713 allocators, allocatorTypes)) 714 return failure(); 715 clauseSegments[pos[allocateClause]] = allocates.size(); 716 clauseSegments[pos[allocateClause] + 1] = allocators.size(); 717 } else if (clauseKeyword == "default") { 718 StringRef defval; 719 llvm::SMLoc loc = parser.getCurrentLocation(); 720 if (checkAllowed(defaultClause) || parser.parseLParen() || 721 parser.parseKeyword(&defval) || parser.parseRParen()) 722 return failure(); 723 // The def prefix is required for the attribute as "private" is a keyword 724 // in C++. 725 if (Optional<ClauseDefault> def = 726 symbolizeClauseDefault(("def" + defval).str())) { 727 result.addAttribute("default_val", 728 ClauseDefaultAttr::get(parser.getContext(), *def)); 729 } else { 730 return parser.emitError(loc, "invalid default clause"); 731 } 732 } else if (clauseKeyword == "proc_bind") { 733 if (checkAllowed(procBindClause) || 734 parseClauseAttr<ClauseProcBindKindAttr>(parser, result, 735 "proc_bind_val", "proc bind")) 736 return failure(); 737 } else if (clauseKeyword == "reduction") { 738 if (checkAllowed(reductionClause) || 739 parseReductionVarList(parser, reductionSymbols, reductionVars, 740 reductionVarTypes)) 741 return failure(); 742 clauseSegments[pos[reductionClause]] = reductionVars.size(); 743 } else if (clauseKeyword == "nowait") { 744 if (checkAllowed(nowaitClause)) 745 return failure(); 746 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 747 result.addAttribute("nowait", attr); 748 } else if (clauseKeyword == "linear") { 749 if (checkAllowed(linearClause) || 750 parseLinearClause(parser, linears, linearTypes, linearSteps)) 751 return failure(); 752 clauseSegments[pos[linearClause]] = linears.size(); 753 clauseSegments[pos[linearClause] + 1] = linearSteps.size(); 754 } else if (clauseKeyword == "schedule") { 755 if (checkAllowed(scheduleClause) || 756 parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize, 757 scheduleChunkType)) 758 return failure(); 759 if (scheduleChunkSize) { 760 clauseSegments[pos[scheduleClause]] = 1; 761 } 762 } else if (clauseKeyword == "collapse") { 763 auto type = parser.getBuilder().getI64Type(); 764 mlir::IntegerAttr attr; 765 if (checkAllowed(collapseClause) || parser.parseLParen() || 766 parser.parseAttribute(attr, type) || parser.parseRParen()) 767 return failure(); 768 result.addAttribute("collapse_val", attr); 769 } else if (clauseKeyword == "ordered") { 770 mlir::IntegerAttr attr; 771 if (checkAllowed(orderedClause)) 772 return failure(); 773 if (succeeded(parser.parseOptionalLParen())) { 774 auto type = parser.getBuilder().getI64Type(); 775 if (parser.parseAttribute(attr, type) || parser.parseRParen()) 776 return failure(); 777 } else { 778 // Use 0 to represent no ordered parameter was specified 779 attr = parser.getBuilder().getI64IntegerAttr(0); 780 } 781 result.addAttribute("ordered_val", attr); 782 } else if (clauseKeyword == "order") { 783 if (checkAllowed(orderClause) || 784 parseClauseAttr<ClauseOrderKindAttr>(parser, result, "order_val", 785 "order")) 786 return failure(); 787 } else if (clauseKeyword == "memory_order") { 788 if (checkAllowed(memoryOrderClause) || 789 parseClauseAttr<ClauseMemoryOrderKindAttr>( 790 parser, result, "memory_order", "memory order")) 791 return failure(); 792 } else if (clauseKeyword == "hint") { 793 IntegerAttr hint; 794 if (checkAllowed(hintClause) || 795 parseSynchronizationHint(parser, hint, false)) 796 return failure(); 797 result.addAttribute("hint", hint); 798 } else { 799 return parser.emitError(parser.getNameLoc()) 800 << clauseKeyword << " is not a valid clause"; 801 } 802 } 803 804 // Add if parameter. 805 if (done[ifClause] && clauseSegments[pos[ifClause]] && 806 failed( 807 parser.resolveOperand(ifCond.first, ifCond.second, result.operands))) 808 return failure(); 809 810 // Add num_threads parameter. 811 if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] && 812 failed(parser.resolveOperand(numThreads.first, numThreads.second, 813 result.operands))) 814 return failure(); 815 816 // Add private parameters. 817 if (done[privateClause] && clauseSegments[pos[privateClause]] && 818 failed(parser.resolveOperands(privates, privateTypes, 819 privates[0].location, result.operands))) 820 return failure(); 821 822 // Add firstprivate parameters. 823 if (done[firstprivateClause] && clauseSegments[pos[firstprivateClause]] && 824 failed(parser.resolveOperands(firstprivates, firstprivateTypes, 825 firstprivates[0].location, 826 result.operands))) 827 return failure(); 828 829 // Add lastprivate parameters. 830 if (done[lastprivateClause] && clauseSegments[pos[lastprivateClause]] && 831 failed(parser.resolveOperands(lastprivates, lastprivateTypes, 832 lastprivates[0].location, result.operands))) 833 return failure(); 834 835 // Add shared parameters. 836 if (done[sharedClause] && clauseSegments[pos[sharedClause]] && 837 failed(parser.resolveOperands(shareds, sharedTypes, shareds[0].location, 838 result.operands))) 839 return failure(); 840 841 // Add copyin parameters. 842 if (done[copyinClause] && clauseSegments[pos[copyinClause]] && 843 failed(parser.resolveOperands(copyins, copyinTypes, copyins[0].location, 844 result.operands))) 845 return failure(); 846 847 // Add allocate parameters. 848 if (done[allocateClause] && clauseSegments[pos[allocateClause]] && 849 failed(parser.resolveOperands(allocates, allocateTypes, 850 allocates[0].location, result.operands))) 851 return failure(); 852 853 // Add allocator parameters. 854 if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] && 855 failed(parser.resolveOperands(allocators, allocatorTypes, 856 allocators[0].location, result.operands))) 857 return failure(); 858 859 // Add reduction parameters and symbols 860 if (done[reductionClause] && clauseSegments[pos[reductionClause]]) { 861 if (failed(parser.resolveOperands(reductionVars, reductionVarTypes, 862 parser.getNameLoc(), result.operands))) 863 return failure(); 864 865 SmallVector<Attribute> reductions(reductionSymbols.begin(), 866 reductionSymbols.end()); 867 result.addAttribute("reductions", 868 parser.getBuilder().getArrayAttr(reductions)); 869 } 870 871 // Add linear parameters 872 if (done[linearClause] && clauseSegments[pos[linearClause]]) { 873 auto linearStepType = parser.getBuilder().getI32Type(); 874 SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType); 875 if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location, 876 result.operands)) || 877 failed(parser.resolveOperands(linearSteps, linearStepTypes, 878 linearSteps[0].location, 879 result.operands))) 880 return failure(); 881 } 882 883 // Add schedule parameters 884 if (done[scheduleClause] && !schedule.empty()) { 885 schedule[0] = llvm::toUpper(schedule[0]); 886 if (Optional<ClauseScheduleKind> sched = 887 symbolizeClauseScheduleKind(schedule)) { 888 auto attr = ClauseScheduleKindAttr::get(parser.getContext(), *sched); 889 result.addAttribute("schedule_val", attr); 890 } else { 891 return parser.emitError(parser.getCurrentLocation(), 892 "invalid schedule kind"); 893 } 894 if (!modifiers.empty()) { 895 llvm::SMLoc loc = parser.getCurrentLocation(); 896 if (Optional<ScheduleModifier> mod = 897 symbolizeScheduleModifier(modifiers[0])) { 898 result.addAttribute( 899 "schedule_modifier", 900 ScheduleModifierAttr::get(parser.getContext(), *mod)); 901 } else { 902 return parser.emitError(loc, "invalid schedule modifier"); 903 } 904 // Only SIMD attribute is allowed here! 905 if (modifiers.size() > 1) { 906 assert(symbolizeScheduleModifier(modifiers[1]) == 907 ScheduleModifier::simd); 908 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 909 result.addAttribute("simd_modifier", attr); 910 } 911 } 912 if (scheduleChunkSize) 913 parser.resolveOperand(*scheduleChunkSize, scheduleChunkType, 914 result.operands); 915 } 916 917 segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end()); 918 919 return success(); 920 } 921 922 /// Parses a parallel operation. 923 /// 924 /// operation ::= `omp.parallel` clause-list 925 /// clause-list ::= clause | clause clause-list 926 /// clause ::= if | num-threads | private | firstprivate | shared | copyin | 927 /// allocate | default | proc-bind 928 /// 929 static ParseResult parseParallelOp(OpAsmParser &parser, 930 OperationState &result) { 931 SmallVector<ClauseType> clauses = { 932 ifClause, numThreadsClause, privateClause, 933 firstprivateClause, sharedClause, copyinClause, 934 allocateClause, defaultClause, procBindClause}; 935 936 SmallVector<int> segments; 937 938 if (failed(parseClauses(parser, result, clauses, segments))) 939 return failure(); 940 941 result.addAttribute("operand_segment_sizes", 942 parser.getBuilder().getI32VectorAttr(segments)); 943 944 Region *body = result.addRegion(); 945 SmallVector<OpAsmParser::OperandType> regionArgs; 946 SmallVector<Type> regionArgTypes; 947 if (parser.parseRegion(*body, regionArgs, regionArgTypes)) 948 return failure(); 949 return success(); 950 } 951 952 //===----------------------------------------------------------------------===// 953 // Parser, printer and verifier for SectionsOp 954 //===----------------------------------------------------------------------===// 955 956 /// Parses an OpenMP Sections operation 957 /// 958 /// sections ::= `omp.sections` clause-list 959 /// clause-list ::= clause clause-list | empty 960 /// clause ::= private | firstprivate | lastprivate | reduction | allocate | 961 /// nowait 962 static ParseResult parseSectionsOp(OpAsmParser &parser, 963 OperationState &result) { 964 965 SmallVector<ClauseType> clauses = {privateClause, firstprivateClause, 966 lastprivateClause, reductionClause, 967 allocateClause, nowaitClause}; 968 969 SmallVector<int> segments; 970 971 if (failed(parseClauses(parser, result, clauses, segments))) 972 return failure(); 973 974 result.addAttribute("operand_segment_sizes", 975 parser.getBuilder().getI32VectorAttr(segments)); 976 977 // Now parse the body. 978 Region *body = result.addRegion(); 979 if (parser.parseRegion(*body)) 980 return failure(); 981 return success(); 982 } 983 984 static void printSectionsOp(OpAsmPrinter &p, SectionsOp op) { 985 p << " "; 986 printDataVars(p, op.private_vars(), "private"); 987 printDataVars(p, op.firstprivate_vars(), "firstprivate"); 988 printDataVars(p, op.lastprivate_vars(), "lastprivate"); 989 990 if (!op.reduction_vars().empty()) 991 printReductionVarList(p, op.reductions(), op.reduction_vars()); 992 993 if (!op.allocate_vars().empty()) 994 printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars()); 995 996 if (op.nowait()) 997 p << "nowait"; 998 999 p << ' '; 1000 p.printRegion(op.region()); 1001 } 1002 1003 static LogicalResult verifySectionsOp(SectionsOp op) { 1004 1005 // A list item may not appear in more than one clause on the same directive, 1006 // except that it may be specified in both firstprivate and lastprivate 1007 // clauses. 1008 for (auto var : op.private_vars()) { 1009 if (llvm::is_contained(op.firstprivate_vars(), var)) 1010 return op.emitOpError() 1011 << "operand used in both private and firstprivate clauses"; 1012 if (llvm::is_contained(op.lastprivate_vars(), var)) 1013 return op.emitOpError() 1014 << "operand used in both private and lastprivate clauses"; 1015 } 1016 1017 if (op.allocate_vars().size() != op.allocators_vars().size()) 1018 return op.emitError( 1019 "expected equal sizes for allocate and allocator variables"); 1020 1021 for (auto &inst : *op.region().begin()) { 1022 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) 1023 op.emitOpError() 1024 << "expected omp.section op or terminator op inside region"; 1025 } 1026 1027 return verifyReductionVarList(op, op.reductions(), op.reduction_vars()); 1028 } 1029 1030 /// Parses an OpenMP Workshare Loop operation 1031 /// 1032 /// wsloop ::= `omp.wsloop` loop-control clause-list 1033 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds 1034 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps 1035 /// steps := `step` `(`ssa-id-list`)` 1036 /// clause-list ::= clause clause-list | empty 1037 /// clause ::= private | firstprivate | lastprivate | linear | schedule | 1038 // collapse | nowait | ordered | order | reduction 1039 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) { 1040 1041 // Parse an opening `(` followed by induction variables followed by `)` 1042 SmallVector<OpAsmParser::OperandType> ivs; 1043 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, 1044 OpAsmParser::Delimiter::Paren)) 1045 return failure(); 1046 1047 int numIVs = static_cast<int>(ivs.size()); 1048 Type loopVarType; 1049 if (parser.parseColonType(loopVarType)) 1050 return failure(); 1051 1052 // Parse loop bounds. 1053 SmallVector<OpAsmParser::OperandType> lower; 1054 if (parser.parseEqual() || 1055 parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) || 1056 parser.resolveOperands(lower, loopVarType, result.operands)) 1057 return failure(); 1058 1059 SmallVector<OpAsmParser::OperandType> upper; 1060 if (parser.parseKeyword("to") || 1061 parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) || 1062 parser.resolveOperands(upper, loopVarType, result.operands)) 1063 return failure(); 1064 1065 if (succeeded(parser.parseOptionalKeyword("inclusive"))) { 1066 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 1067 result.addAttribute("inclusive", attr); 1068 } 1069 1070 // Parse step values. 1071 SmallVector<OpAsmParser::OperandType> steps; 1072 if (parser.parseKeyword("step") || 1073 parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) || 1074 parser.resolveOperands(steps, loopVarType, result.operands)) 1075 return failure(); 1076 1077 SmallVector<ClauseType> clauses = { 1078 privateClause, firstprivateClause, lastprivateClause, linearClause, 1079 reductionClause, collapseClause, orderClause, orderedClause, 1080 nowaitClause, scheduleClause}; 1081 SmallVector<int> segments{numIVs, numIVs, numIVs}; 1082 if (failed(parseClauses(parser, result, clauses, segments))) 1083 return failure(); 1084 1085 result.addAttribute("operand_segment_sizes", 1086 parser.getBuilder().getI32VectorAttr(segments)); 1087 1088 // Now parse the body. 1089 Region *body = result.addRegion(); 1090 SmallVector<Type> ivTypes(numIVs, loopVarType); 1091 SmallVector<OpAsmParser::OperandType> blockArgs(ivs); 1092 if (parser.parseRegion(*body, blockArgs, ivTypes)) 1093 return failure(); 1094 return success(); 1095 } 1096 1097 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) { 1098 auto args = op.getRegion().front().getArguments(); 1099 p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound() 1100 << ") to (" << op.upperBound() << ") "; 1101 if (op.inclusive()) { 1102 p << "inclusive "; 1103 } 1104 p << "step (" << op.step() << ") "; 1105 1106 printDataVars(p, op.private_vars(), "private"); 1107 printDataVars(p, op.firstprivate_vars(), "firstprivate"); 1108 printDataVars(p, op.lastprivate_vars(), "lastprivate"); 1109 1110 if (!op.linear_vars().empty()) 1111 printLinearClause(p, op.linear_vars(), op.linear_step_vars()); 1112 1113 if (auto sched = op.schedule_val()) 1114 printScheduleClause(p, sched.getValue(), op.schedule_modifier(), 1115 op.simd_modifier(), op.schedule_chunk_var()); 1116 1117 if (auto collapse = op.collapse_val()) 1118 p << "collapse(" << collapse << ") "; 1119 1120 if (op.nowait()) 1121 p << "nowait "; 1122 1123 if (auto ordered = op.ordered_val()) 1124 p << "ordered(" << ordered << ") "; 1125 1126 if (auto order = op.order_val()) 1127 p << "order(" << stringifyClauseOrderKind(*order) << ") "; 1128 1129 if (!op.reduction_vars().empty()) 1130 printReductionVarList(p, op.reductions(), op.reduction_vars()); 1131 1132 p << ' '; 1133 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 1134 } 1135 1136 //===----------------------------------------------------------------------===// 1137 // ReductionOp 1138 //===----------------------------------------------------------------------===// 1139 1140 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, 1141 Region ®ion) { 1142 if (parser.parseOptionalKeyword("atomic")) 1143 return success(); 1144 return parser.parseRegion(region); 1145 } 1146 1147 static void printAtomicReductionRegion(OpAsmPrinter &printer, 1148 ReductionDeclareOp op, Region ®ion) { 1149 if (region.empty()) 1150 return; 1151 printer << "atomic "; 1152 printer.printRegion(region); 1153 } 1154 1155 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) { 1156 if (op.initializerRegion().empty()) 1157 return op.emitOpError() << "expects non-empty initializer region"; 1158 Block &initializerEntryBlock = op.initializerRegion().front(); 1159 if (initializerEntryBlock.getNumArguments() != 1 || 1160 initializerEntryBlock.getArgument(0).getType() != op.type()) { 1161 return op.emitOpError() << "expects initializer region with one argument " 1162 "of the reduction type"; 1163 } 1164 1165 for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) { 1166 if (yieldOp.results().size() != 1 || 1167 yieldOp.results().getTypes()[0] != op.type()) 1168 return op.emitOpError() << "expects initializer region to yield a value " 1169 "of the reduction type"; 1170 } 1171 1172 if (op.reductionRegion().empty()) 1173 return op.emitOpError() << "expects non-empty reduction region"; 1174 Block &reductionEntryBlock = op.reductionRegion().front(); 1175 if (reductionEntryBlock.getNumArguments() != 2 || 1176 reductionEntryBlock.getArgumentTypes()[0] != 1177 reductionEntryBlock.getArgumentTypes()[1] || 1178 reductionEntryBlock.getArgumentTypes()[0] != op.type()) 1179 return op.emitOpError() << "expects reduction region with two arguments of " 1180 "the reduction type"; 1181 for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) { 1182 if (yieldOp.results().size() != 1 || 1183 yieldOp.results().getTypes()[0] != op.type()) 1184 return op.emitOpError() << "expects reduction region to yield a value " 1185 "of the reduction type"; 1186 } 1187 1188 if (op.atomicReductionRegion().empty()) 1189 return success(); 1190 1191 Block &atomicReductionEntryBlock = op.atomicReductionRegion().front(); 1192 if (atomicReductionEntryBlock.getNumArguments() != 2 || 1193 atomicReductionEntryBlock.getArgumentTypes()[0] != 1194 atomicReductionEntryBlock.getArgumentTypes()[1]) 1195 return op.emitOpError() << "expects atomic reduction region with two " 1196 "arguments of the same type"; 1197 auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0] 1198 .dyn_cast<PointerLikeType>(); 1199 if (!ptrType || ptrType.getElementType() != op.type()) 1200 return op.emitOpError() << "expects atomic reduction region arguments to " 1201 "be accumulators containing the reduction type"; 1202 return success(); 1203 } 1204 1205 static LogicalResult verifyReductionOp(ReductionOp op) { 1206 // TODO: generalize this to an op interface when there is more than one op 1207 // that supports reductions. 1208 auto container = op->getParentOfType<WsLoopOp>(); 1209 for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) 1210 if (container.reduction_vars()[i] == op.accumulator()) 1211 return success(); 1212 1213 return op.emitOpError() << "the accumulator is not used by the parent"; 1214 } 1215 1216 //===----------------------------------------------------------------------===// 1217 // WsLoopOp 1218 //===----------------------------------------------------------------------===// 1219 1220 void WsLoopOp::build(OpBuilder &builder, OperationState &state, 1221 ValueRange lowerBound, ValueRange upperBound, 1222 ValueRange step, ArrayRef<NamedAttribute> attributes) { 1223 build(builder, state, TypeRange(), lowerBound, upperBound, step, 1224 /*privateVars=*/ValueRange(), 1225 /*firstprivateVars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(), 1226 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), 1227 /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr, 1228 /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr, 1229 /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr, 1230 /*inclusive=*/nullptr, /*buildBody=*/false); 1231 state.addAttributes(attributes); 1232 } 1233 1234 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes, 1235 ValueRange operands, ArrayRef<NamedAttribute> attributes) { 1236 state.addOperands(operands); 1237 state.addAttributes(attributes); 1238 (void)state.addRegion(); 1239 assert(resultTypes.empty() && "mismatched number of return types"); 1240 state.addTypes(resultTypes); 1241 } 1242 1243 void WsLoopOp::build(OpBuilder &builder, OperationState &result, 1244 TypeRange typeRange, ValueRange lowerBounds, 1245 ValueRange upperBounds, ValueRange steps, 1246 ValueRange privateVars, ValueRange firstprivateVars, 1247 ValueRange lastprivateVars, ValueRange linearVars, 1248 ValueRange linearStepVars, ValueRange reductionVars, 1249 StringAttr scheduleVal, Value scheduleChunkVar, 1250 IntegerAttr collapseVal, UnitAttr nowait, 1251 IntegerAttr orderedVal, StringAttr orderVal, 1252 UnitAttr inclusive, bool buildBody) { 1253 result.addOperands(lowerBounds); 1254 result.addOperands(upperBounds); 1255 result.addOperands(steps); 1256 result.addOperands(privateVars); 1257 result.addOperands(firstprivateVars); 1258 result.addOperands(linearVars); 1259 result.addOperands(linearStepVars); 1260 if (scheduleChunkVar) 1261 result.addOperands(scheduleChunkVar); 1262 1263 if (scheduleVal) 1264 result.addAttribute("schedule_val", scheduleVal); 1265 if (collapseVal) 1266 result.addAttribute("collapse_val", collapseVal); 1267 if (nowait) 1268 result.addAttribute("nowait", nowait); 1269 if (orderedVal) 1270 result.addAttribute("ordered_val", orderedVal); 1271 if (orderVal) 1272 result.addAttribute("order", orderVal); 1273 if (inclusive) 1274 result.addAttribute("inclusive", inclusive); 1275 result.addAttribute( 1276 WsLoopOp::getOperandSegmentSizeAttr(), 1277 builder.getI32VectorAttr( 1278 {static_cast<int32_t>(lowerBounds.size()), 1279 static_cast<int32_t>(upperBounds.size()), 1280 static_cast<int32_t>(steps.size()), 1281 static_cast<int32_t>(privateVars.size()), 1282 static_cast<int32_t>(firstprivateVars.size()), 1283 static_cast<int32_t>(lastprivateVars.size()), 1284 static_cast<int32_t>(linearVars.size()), 1285 static_cast<int32_t>(linearStepVars.size()), 1286 static_cast<int32_t>(reductionVars.size()), 1287 static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)})); 1288 1289 Region *bodyRegion = result.addRegion(); 1290 if (buildBody) { 1291 OpBuilder::InsertionGuard guard(builder); 1292 unsigned numIVs = steps.size(); 1293 SmallVector<Type, 8> argTypes(numIVs, steps.getType().front()); 1294 SmallVector<Location, 8> argLocs(numIVs, result.location); 1295 builder.createBlock(bodyRegion, {}, argTypes, argLocs); 1296 } 1297 } 1298 1299 static LogicalResult verifyWsLoopOp(WsLoopOp op) { 1300 return verifyReductionVarList(op, op.reductions(), op.reduction_vars()); 1301 } 1302 1303 //===----------------------------------------------------------------------===// 1304 // Verifier for critical construct (2.17.1) 1305 //===----------------------------------------------------------------------===// 1306 1307 static LogicalResult verifyCriticalDeclareOp(CriticalDeclareOp op) { 1308 return verifySynchronizationHint(op, op.hint()); 1309 } 1310 1311 static LogicalResult verifyCriticalOp(CriticalOp op) { 1312 1313 if (op.nameAttr()) { 1314 auto symbolRef = op.nameAttr().cast<SymbolRefAttr>(); 1315 auto decl = 1316 SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef); 1317 if (!decl) { 1318 return op.emitOpError() << "expected symbol reference " << symbolRef 1319 << " to point to a critical declaration"; 1320 } 1321 } 1322 1323 return success(); 1324 } 1325 1326 //===----------------------------------------------------------------------===// 1327 // Verifier for ordered construct 1328 //===----------------------------------------------------------------------===// 1329 1330 static LogicalResult verifyOrderedOp(OrderedOp op) { 1331 auto container = op->getParentOfType<WsLoopOp>(); 1332 if (!container || !container.ordered_valAttr() || 1333 container.ordered_valAttr().getInt() == 0) 1334 return op.emitOpError() << "ordered depend directive must be closely " 1335 << "nested inside a worksharing-loop with ordered " 1336 << "clause with parameter present"; 1337 1338 if (container.ordered_valAttr().getInt() != 1339 (int64_t)op.num_loops_val().getValue()) 1340 return op.emitOpError() << "number of variables in depend clause does not " 1341 << "match number of iteration variables in the " 1342 << "doacross loop"; 1343 1344 return success(); 1345 } 1346 1347 static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) { 1348 // TODO: The code generation for ordered simd directive is not supported yet. 1349 if (op.simd()) 1350 return failure(); 1351 1352 if (auto container = op->getParentOfType<WsLoopOp>()) { 1353 if (!container.ordered_valAttr() || 1354 container.ordered_valAttr().getInt() != 0) 1355 return op.emitOpError() << "ordered region must be closely nested inside " 1356 << "a worksharing-loop region with an ordered " 1357 << "clause without parameter present"; 1358 } 1359 1360 return success(); 1361 } 1362 1363 //===----------------------------------------------------------------------===// 1364 // AtomicReadOp 1365 //===----------------------------------------------------------------------===// 1366 1367 /// Parser for AtomicReadOp 1368 /// 1369 /// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type 1370 /// address ::= operand `:` type 1371 static ParseResult parseAtomicReadOp(OpAsmParser &parser, 1372 OperationState &result) { 1373 OpAsmParser::OperandType x, v; 1374 Type addressType; 1375 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1376 SmallVector<int> segments; 1377 1378 if (parser.parseOperand(v) || parser.parseEqual() || parser.parseOperand(x) || 1379 parseClauses(parser, result, clauses, segments) || 1380 parser.parseColonType(addressType) || 1381 parser.resolveOperand(x, addressType, result.operands) || 1382 parser.resolveOperand(v, addressType, result.operands)) 1383 return failure(); 1384 1385 return success(); 1386 } 1387 1388 /// Printer for AtomicReadOp 1389 static void printAtomicReadOp(OpAsmPrinter &p, AtomicReadOp op) { 1390 p << " " << op.v() << " = " << op.x() << " "; 1391 if (auto mo = op.memory_order()) 1392 p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") "; 1393 if (op.hintAttr()) 1394 printSynchronizationHint(p << " ", op, op.hintAttr()); 1395 p << ": " << op.x().getType(); 1396 } 1397 1398 /// Verifier for AtomicReadOp 1399 static LogicalResult verifyAtomicReadOp(AtomicReadOp op) { 1400 if (auto mo = op.memory_order()) { 1401 if (*mo == ClauseMemoryOrderKind::acq_rel || 1402 *mo == ClauseMemoryOrderKind::release) { 1403 return op.emitError( 1404 "memory-order must not be acq_rel or release for atomic reads"); 1405 } 1406 } 1407 if (op.x() == op.v()) 1408 return op.emitError( 1409 "read and write must not be to the same location for atomic reads"); 1410 return verifySynchronizationHint(op, op.hint()); 1411 } 1412 1413 //===----------------------------------------------------------------------===// 1414 // AtomicWriteOp 1415 //===----------------------------------------------------------------------===// 1416 1417 /// Parser for AtomicWriteOp 1418 /// 1419 /// operation ::= `omp.atomic.write` atomic-clause-list operands 1420 /// operands ::= address `,` value 1421 /// address ::= operand `:` type 1422 /// value ::= operand `:` type 1423 static ParseResult parseAtomicWriteOp(OpAsmParser &parser, 1424 OperationState &result) { 1425 OpAsmParser::OperandType address, value; 1426 Type addrType, valueType; 1427 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1428 SmallVector<int> segments; 1429 1430 if (parser.parseOperand(address) || parser.parseEqual() || 1431 parser.parseOperand(value) || 1432 parseClauses(parser, result, clauses, segments) || 1433 parser.parseColonType(addrType) || parser.parseComma() || 1434 parser.parseType(valueType) || 1435 parser.resolveOperand(address, addrType, result.operands) || 1436 parser.resolveOperand(value, valueType, result.operands)) 1437 return failure(); 1438 return success(); 1439 } 1440 1441 /// Printer for AtomicWriteOp 1442 static void printAtomicWriteOp(OpAsmPrinter &p, AtomicWriteOp op) { 1443 p << " " << op.address() << " = " << op.value() << " "; 1444 if (auto mo = op.memory_order()) 1445 p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") "; 1446 if (op.hintAttr()) 1447 printSynchronizationHint(p, op, op.hintAttr()); 1448 p << ": " << op.address().getType() << ", " << op.value().getType(); 1449 } 1450 1451 /// Verifier for AtomicWriteOp 1452 static LogicalResult verifyAtomicWriteOp(AtomicWriteOp op) { 1453 if (auto mo = op.memory_order()) { 1454 if (*mo == ClauseMemoryOrderKind::acq_rel || 1455 *mo == ClauseMemoryOrderKind::acquire) { 1456 return op.emitError( 1457 "memory-order must not be acq_rel or acquire for atomic writes"); 1458 } 1459 } 1460 return verifySynchronizationHint(op, op.hint()); 1461 } 1462 1463 //===----------------------------------------------------------------------===// 1464 // AtomicUpdateOp 1465 //===----------------------------------------------------------------------===// 1466 1467 /// Parser for AtomicUpdateOp 1468 /// 1469 /// operation ::= `omp.atomic.update` atomic-clause-list region 1470 static ParseResult parseAtomicUpdateOp(OpAsmParser &parser, 1471 OperationState &result) { 1472 SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause}; 1473 SmallVector<int> segments; 1474 OpAsmParser::OperandType x, y, z; 1475 Type xType, exprType; 1476 StringRef binOp; 1477 1478 // x = y `op` z : xtype, exprtype 1479 if (parser.parseOperand(x) || parser.parseEqual() || parser.parseOperand(y) || 1480 parser.parseKeyword(&binOp) || parser.parseOperand(z) || 1481 parseClauses(parser, result, clauses, segments) || parser.parseColon() || 1482 parser.parseType(xType) || parser.parseComma() || 1483 parser.parseType(exprType) || 1484 parser.resolveOperand(x, xType, result.operands)) { 1485 return failure(); 1486 } 1487 1488 auto binOpEnum = AtomicBinOpKindToEnum(binOp.upper()); 1489 if (!binOpEnum) 1490 return parser.emitError(parser.getNameLoc()) 1491 << "invalid atomic bin op in atomic update\n"; 1492 auto attr = 1493 parser.getBuilder().getI64IntegerAttr((int64_t)binOpEnum.getValue()); 1494 result.addAttribute("binop", attr); 1495 1496 OpAsmParser::OperandType expr; 1497 if (x.name == y.name && x.number == y.number) { 1498 expr = z; 1499 result.addAttribute("isXBinopExpr", parser.getBuilder().getUnitAttr()); 1500 } else if (x.name == z.name && x.number == z.number) { 1501 expr = y; 1502 } else { 1503 return parser.emitError(parser.getNameLoc()) 1504 << "atomic update variable " << x.name 1505 << " not found in the RHS of the assignment statement in an" 1506 " atomic.update operation"; 1507 } 1508 return parser.resolveOperand(expr, exprType, result.operands); 1509 } 1510 1511 /// Printer for AtomicUpdateOp 1512 static void printAtomicUpdateOp(OpAsmPrinter &p, AtomicUpdateOp op) { 1513 p << " " << op.x() << " = "; 1514 Value y, z; 1515 if (op.isXBinopExpr()) { 1516 y = op.x(); 1517 z = op.expr(); 1518 } else { 1519 y = op.expr(); 1520 z = op.x(); 1521 } 1522 p << y << " " << AtomicBinOpKindToString(op.binop()).lower() << " " << z 1523 << " "; 1524 if (auto mo = op.memory_order()) 1525 p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") "; 1526 if (op.hintAttr()) 1527 printSynchronizationHint(p, op, op.hintAttr()); 1528 p << ": " << op.x().getType() << ", " << op.expr().getType(); 1529 } 1530 1531 /// Verifier for AtomicUpdateOp 1532 static LogicalResult verifyAtomicUpdateOp(AtomicUpdateOp op) { 1533 if (auto mo = op.memory_order()) { 1534 if (*mo == ClauseMemoryOrderKind::acq_rel || 1535 *mo == ClauseMemoryOrderKind::acquire) { 1536 return op.emitError( 1537 "memory-order must not be acq_rel or acquire for atomic updates"); 1538 } 1539 } 1540 return success(); 1541 } 1542 1543 #define GET_ATTRDEF_CLASSES 1544 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" 1545 1546 #define GET_OP_CLASSES 1547 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 1548