1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the OpenMP dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/DialectImplementation.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/IR/OperationSupport.h" 19 20 #include "llvm/ADT/BitVector.h" 21 #include "llvm/ADT/SmallString.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/ADT/StringRef.h" 24 #include "llvm/ADT/StringSwitch.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 #include <cstddef> 27 28 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc" 29 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" 30 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc" 31 32 using namespace mlir; 33 using namespace mlir::omp; 34 35 namespace { 36 /// Model for pointer-like types that already provide a `getElementType` method. 37 template <typename T> 38 struct PointerLikeModel 39 : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> { 40 Type getElementType(Type pointer) const { 41 return pointer.cast<T>().getElementType(); 42 } 43 }; 44 } // namespace 45 46 void OpenMPDialect::initialize() { 47 addOperations< 48 #define GET_OP_LIST 49 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 50 >(); 51 addAttributes< 52 #define GET_ATTRDEF_LIST 53 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" 54 >(); 55 56 LLVM::LLVMPointerType::attachInterface< 57 PointerLikeModel<LLVM::LLVMPointerType>>(*getContext()); 58 MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext()); 59 } 60 61 //===----------------------------------------------------------------------===// 62 // ParallelOp 63 //===----------------------------------------------------------------------===// 64 65 void ParallelOp::build(OpBuilder &builder, OperationState &state, 66 ArrayRef<NamedAttribute> attributes) { 67 ParallelOp::build( 68 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, 69 /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(), 70 /*proc_bind_val=*/nullptr); 71 state.addAttributes(attributes); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // Parser and printer for Allocate Clause 76 //===----------------------------------------------------------------------===// 77 78 /// Parse an allocate clause with allocators and a list of operands with types. 79 /// 80 /// allocate-operand-list :: = allocate-operand | 81 /// allocator-operand `,` allocate-operand-list 82 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 83 /// ssa-id-and-type ::= ssa-id `:` type 84 static ParseResult parseAllocateAndAllocator( 85 OpAsmParser &parser, 86 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate, 87 SmallVectorImpl<Type> &typesAllocate, 88 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator, 89 SmallVectorImpl<Type> &typesAllocator) { 90 91 return parser.parseCommaSeparatedList([&]() -> ParseResult { 92 OpAsmParser::OperandType operand; 93 Type type; 94 if (parser.parseOperand(operand) || parser.parseColonType(type)) 95 return failure(); 96 operandsAllocator.push_back(operand); 97 typesAllocator.push_back(type); 98 if (parser.parseArrow()) 99 return failure(); 100 if (parser.parseOperand(operand) || parser.parseColonType(type)) 101 return failure(); 102 103 operandsAllocate.push_back(operand); 104 typesAllocate.push_back(type); 105 return success(); 106 }); 107 } 108 109 /// Print allocate clause 110 static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, 111 OperandRange varsAllocate, 112 TypeRange typesAllocate, 113 OperandRange varsAllocator, 114 TypeRange typesAllocator) { 115 for (unsigned i = 0; i < varsAllocate.size(); ++i) { 116 std::string separator = i == varsAllocate.size() - 1 ? "" : ", "; 117 p << varsAllocator[i] << " : " << typesAllocator[i] << " -> "; 118 p << varsAllocate[i] << " : " << typesAllocate[i] << separator; 119 } 120 } 121 122 /// Parse a clause attribute (StringEnumAttr) 123 template <typename ClauseAttr> 124 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) { 125 using ClauseT = decltype(std::declval<ClauseAttr>().getValue()); 126 StringRef enumStr; 127 SMLoc loc = parser.getCurrentLocation(); 128 if (parser.parseKeyword(&enumStr)) 129 return failure(); 130 if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) { 131 attr = ClauseAttr::get(parser.getContext(), *enumValue); 132 return success(); 133 } 134 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'"; 135 } 136 137 template <typename ClauseAttr> 138 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) { 139 p << stringifyEnum(attr.getValue()); 140 } 141 142 //===----------------------------------------------------------------------===// 143 // Parser and printer for Procbind Clause 144 //===----------------------------------------------------------------------===// 145 146 ParseResult parseProcBindKind(OpAsmParser &parser, 147 omp::ClauseProcBindKindAttr &procBindAttr) { 148 StringRef procBindStr; 149 if (parser.parseKeyword(&procBindStr)) 150 return failure(); 151 if (auto procBindVal = symbolizeClauseProcBindKind(procBindStr)) { 152 procBindAttr = 153 ClauseProcBindKindAttr::get(parser.getContext(), *procBindVal); 154 return success(); 155 } 156 return failure(); 157 } 158 159 void printProcBindKind(OpAsmPrinter &p, Operation *op, 160 omp::ClauseProcBindKindAttr procBindAttr) { 161 p << stringifyClauseProcBindKind(procBindAttr.getValue()); 162 } 163 164 LogicalResult ParallelOp::verify() { 165 if (allocate_vars().size() != allocators_vars().size()) 166 return emitError( 167 "expected equal sizes for allocate and allocator variables"); 168 return success(); 169 } 170 171 //===----------------------------------------------------------------------===// 172 // Parser and printer for Linear Clause 173 //===----------------------------------------------------------------------===// 174 175 /// linear ::= `linear` `(` linear-list `)` 176 /// linear-list := linear-val | linear-val linear-list 177 /// linear-val := ssa-id-and-type `=` ssa-id-and-type 178 static ParseResult 179 parseLinearClause(OpAsmParser &parser, 180 SmallVectorImpl<OpAsmParser::OperandType> &vars, 181 SmallVectorImpl<Type> &types, 182 SmallVectorImpl<OpAsmParser::OperandType> &stepVars) { 183 if (parser.parseLParen()) 184 return failure(); 185 186 do { 187 OpAsmParser::OperandType var; 188 Type type; 189 OpAsmParser::OperandType stepVar; 190 if (parser.parseOperand(var) || parser.parseEqual() || 191 parser.parseOperand(stepVar) || parser.parseColonType(type)) 192 return failure(); 193 194 vars.push_back(var); 195 types.push_back(type); 196 stepVars.push_back(stepVar); 197 } while (succeeded(parser.parseOptionalComma())); 198 199 if (parser.parseRParen()) 200 return failure(); 201 202 return success(); 203 } 204 205 /// Print Linear Clause 206 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars, 207 OperandRange linearStepVars) { 208 size_t linearVarsSize = linearVars.size(); 209 p << "linear("; 210 for (unsigned i = 0; i < linearVarsSize; ++i) { 211 std::string separator = i == linearVarsSize - 1 ? ") " : ", "; 212 p << linearVars[i]; 213 if (linearStepVars.size() > i) 214 p << " = " << linearStepVars[i]; 215 p << " : " << linearVars[i].getType() << separator; 216 } 217 } 218 219 //===----------------------------------------------------------------------===// 220 // Parser, printer and verifier for Schedule Clause 221 //===----------------------------------------------------------------------===// 222 223 static ParseResult 224 verifyScheduleModifiers(OpAsmParser &parser, 225 SmallVectorImpl<SmallString<12>> &modifiers) { 226 if (modifiers.size() > 2) 227 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)"; 228 for (const auto &mod : modifiers) { 229 // Translate the string. If it has no value, then it was not a valid 230 // modifier! 231 auto symbol = symbolizeScheduleModifier(mod); 232 if (!symbol.hasValue()) 233 return parser.emitError(parser.getNameLoc()) 234 << " unknown modifier type: " << mod; 235 } 236 237 // If we have one modifier that is "simd", then stick a "none" modiifer in 238 // index 0. 239 if (modifiers.size() == 1) { 240 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) { 241 modifiers.push_back(modifiers[0]); 242 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none); 243 } 244 } else if (modifiers.size() == 2) { 245 // If there are two modifier: 246 // First modifier should not be simd, second one should be simd 247 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd || 248 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd) 249 return parser.emitError(parser.getNameLoc()) 250 << " incorrect modifier order"; 251 } 252 return success(); 253 } 254 255 /// schedule ::= `schedule` `(` sched-list `)` 256 /// sched-list ::= sched-val | sched-val sched-list | 257 /// sched-val `,` sched-modifier 258 /// sched-val ::= sched-with-chunk | sched-wo-chunk 259 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 260 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided` 261 /// sched-wo-chunk ::= `auto` | `runtime` 262 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val 263 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none` 264 static ParseResult 265 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, 266 SmallVectorImpl<SmallString<12>> &modifiers, 267 Optional<OpAsmParser::OperandType> &chunkSize, 268 Type &chunkType) { 269 if (parser.parseLParen()) 270 return failure(); 271 272 StringRef keyword; 273 if (parser.parseKeyword(&keyword)) 274 return failure(); 275 276 schedule = keyword; 277 if (keyword == "static" || keyword == "dynamic" || keyword == "guided") { 278 if (succeeded(parser.parseOptionalEqual())) { 279 chunkSize = OpAsmParser::OperandType{}; 280 if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType)) 281 return failure(); 282 } else { 283 chunkSize = llvm::NoneType::None; 284 } 285 } else if (keyword == "auto" || keyword == "runtime") { 286 chunkSize = llvm::NoneType::None; 287 } else { 288 return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; 289 } 290 291 // If there is a comma, we have one or more modifiers.. 292 while (succeeded(parser.parseOptionalComma())) { 293 StringRef mod; 294 if (parser.parseKeyword(&mod)) 295 return failure(); 296 modifiers.push_back(mod); 297 } 298 299 if (parser.parseRParen()) 300 return failure(); 301 302 if (verifyScheduleModifiers(parser, modifiers)) 303 return failure(); 304 305 return success(); 306 } 307 308 /// Print schedule clause 309 static void printScheduleClause(OpAsmPrinter &p, ClauseScheduleKind sched, 310 Optional<ScheduleModifier> modifier, bool simd, 311 Value scheduleChunkVar) { 312 p << "schedule(" << stringifyClauseScheduleKind(sched).lower(); 313 if (scheduleChunkVar) 314 p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType(); 315 if (modifier) 316 p << ", " << stringifyScheduleModifier(*modifier); 317 if (simd) 318 p << ", simd"; 319 p << ") "; 320 } 321 322 //===----------------------------------------------------------------------===// 323 // Parser, printer and verifier for ReductionVarList 324 //===----------------------------------------------------------------------===// 325 326 /// reduction-entry-list ::= reduction-entry 327 /// | reduction-entry-list `,` reduction-entry 328 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type 329 static ParseResult parseReductionVarList( 330 OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &operands, 331 SmallVectorImpl<Type> &types, ArrayAttr &redcuctionSymbols) { 332 SmallVector<SymbolRefAttr> reductionVec; 333 do { 334 if (parser.parseAttribute(reductionVec.emplace_back()) || 335 parser.parseArrow() || parser.parseOperand(operands.emplace_back()) || 336 parser.parseColonType(types.emplace_back())) 337 return failure(); 338 } while (succeeded(parser.parseOptionalComma())); 339 SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end()); 340 redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions); 341 return success(); 342 } 343 344 /// Print Reduction clause 345 static void printReductionVarList(OpAsmPrinter &p, Operation *op, 346 OperandRange reductionVars, 347 TypeRange reductionTypes, 348 Optional<ArrayAttr> reductions) { 349 for (unsigned i = 0, e = reductions->size(); i < e; ++i) { 350 if (i != 0) 351 p << ", "; 352 p << (*reductions)[i] << " -> " << reductionVars[i] << " : " 353 << reductionVars[i].getType(); 354 } 355 } 356 357 /// Verifies Reduction Clause 358 static LogicalResult verifyReductionVarList(Operation *op, 359 Optional<ArrayAttr> reductions, 360 OperandRange reductionVars) { 361 if (!reductionVars.empty()) { 362 if (!reductions || reductions->size() != reductionVars.size()) 363 return op->emitOpError() 364 << "expected as many reduction symbol references " 365 "as reduction variables"; 366 } else { 367 if (reductions) 368 return op->emitOpError() << "unexpected reduction symbol references"; 369 return success(); 370 } 371 372 DenseSet<Value> accumulators; 373 for (auto args : llvm::zip(reductionVars, *reductions)) { 374 Value accum = std::get<0>(args); 375 376 if (!accumulators.insert(accum).second) 377 return op->emitOpError() << "accumulator variable used more than once"; 378 379 Type varType = accum.getType().cast<PointerLikeType>(); 380 auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>(); 381 auto decl = 382 SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef); 383 if (!decl) 384 return op->emitOpError() << "expected symbol reference " << symbolRef 385 << " to point to a reduction declaration"; 386 387 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) 388 return op->emitOpError() 389 << "expected accumulator (" << varType 390 << ") to be the same type as reduction declaration (" 391 << decl.getAccumulatorType() << ")"; 392 } 393 394 return success(); 395 } 396 397 //===----------------------------------------------------------------------===// 398 // Parser, printer and verifier for Synchronization Hint (2.17.12) 399 //===----------------------------------------------------------------------===// 400 401 /// Parses a Synchronization Hint clause. The value of hint is an integer 402 /// which is a combination of different hints from `omp_sync_hint_t`. 403 /// 404 /// hint-clause = `hint` `(` hint-value `)` 405 static ParseResult parseSynchronizationHint(OpAsmParser &parser, 406 IntegerAttr &hintAttr) { 407 StringRef hintKeyword; 408 int64_t hint = 0; 409 do { 410 if (failed(parser.parseKeyword(&hintKeyword))) 411 return failure(); 412 if (hintKeyword == "uncontended") 413 hint |= 1; 414 else if (hintKeyword == "contended") 415 hint |= 2; 416 else if (hintKeyword == "nonspeculative") 417 hint |= 4; 418 else if (hintKeyword == "speculative") 419 hint |= 8; 420 else 421 return parser.emitError(parser.getCurrentLocation()) 422 << hintKeyword << " is not a valid hint"; 423 } while (succeeded(parser.parseOptionalComma())); 424 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint); 425 return success(); 426 } 427 428 /// Prints a Synchronization Hint clause 429 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, 430 IntegerAttr hintAttr) { 431 int64_t hint = hintAttr.getInt(); 432 433 if (hint == 0) 434 return; 435 436 // Helper function to get n-th bit from the right end of `value` 437 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 438 439 bool uncontended = bitn(hint, 0); 440 bool contended = bitn(hint, 1); 441 bool nonspeculative = bitn(hint, 2); 442 bool speculative = bitn(hint, 3); 443 444 SmallVector<StringRef> hints; 445 if (uncontended) 446 hints.push_back("uncontended"); 447 if (contended) 448 hints.push_back("contended"); 449 if (nonspeculative) 450 hints.push_back("nonspeculative"); 451 if (speculative) 452 hints.push_back("speculative"); 453 454 llvm::interleaveComma(hints, p); 455 } 456 457 /// Verifies a synchronization hint clause 458 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) { 459 460 // Helper function to get n-th bit from the right end of `value` 461 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 462 463 bool uncontended = bitn(hint, 0); 464 bool contended = bitn(hint, 1); 465 bool nonspeculative = bitn(hint, 2); 466 bool speculative = bitn(hint, 3); 467 468 if (uncontended && contended) 469 return op->emitOpError() << "the hints omp_sync_hint_uncontended and " 470 "omp_sync_hint_contended cannot be combined"; 471 if (nonspeculative && speculative) 472 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and " 473 "omp_sync_hint_speculative cannot be combined."; 474 return success(); 475 } 476 477 enum ClauseType { 478 allocateClause, 479 reductionClause, 480 nowaitClause, 481 linearClause, 482 scheduleClause, 483 collapseClause, 484 orderClause, 485 orderedClause, 486 COUNT 487 }; 488 489 //===----------------------------------------------------------------------===// 490 // Parser for Clause List 491 //===----------------------------------------------------------------------===// 492 493 /// Parse a list of clauses. The clauses can appear in any order, but their 494 /// operand segment indices are in the same order that they are passed in the 495 /// `clauses` list. The operand segments are added over the prevSegments 496 497 /// clause-list ::= clause clause-list | empty 498 /// clause ::= allocate | reduction | nowait | linear | schedule | collapse 499 /// | order | ordered 500 /// allocate ::= `allocate` `(` allocate-operand-list `)` 501 /// reduction ::= `reduction` `(` reduction-entry-list `)` 502 /// nowait ::= `nowait` 503 /// linear ::= `linear` `(` linear-list `)` 504 /// schedule ::= `schedule` `(` sched-list `)` 505 /// collapse ::= `collapse` `(` ssa-id-and-type `)` 506 /// order ::= `order` `(` `concurrent` `)` 507 /// ordered ::= `ordered` `(` ssa-id-and-type `)` 508 /// 509 /// Note that each clause can only appear once in the clase-list. 510 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result, 511 SmallVectorImpl<ClauseType> &clauses, 512 SmallVectorImpl<int> &segments) { 513 514 // Check done[clause] to see if it has been parsed already 515 BitVector done(ClauseType::COUNT, false); 516 517 // See pos[clause] to get position of clause in operand segments 518 SmallVector<int> pos(ClauseType::COUNT, -1); 519 520 // Stores the last parsed clause keyword 521 StringRef clauseKeyword; 522 StringRef opName = result.name.getStringRef(); 523 524 // Containers for storing operands, types and attributes for various clauses 525 SmallVector<OpAsmParser::OperandType> allocates, allocators; 526 SmallVector<Type> allocateTypes, allocatorTypes; 527 528 ArrayAttr reductions; 529 SmallVector<OpAsmParser::OperandType> reductionVars; 530 SmallVector<Type> reductionVarTypes; 531 532 SmallVector<OpAsmParser::OperandType> linears; 533 SmallVector<Type> linearTypes; 534 SmallVector<OpAsmParser::OperandType> linearSteps; 535 536 SmallString<8> schedule; 537 SmallVector<SmallString<12>> modifiers; 538 Optional<OpAsmParser::OperandType> scheduleChunkSize; 539 Type scheduleChunkType; 540 541 // Compute the position of clauses in operand segments 542 int currPos = 0; 543 for (ClauseType clause : clauses) { 544 545 // Skip the following clauses - they do not take any position in operand 546 // segments 547 if (clause == nowaitClause || clause == collapseClause || 548 clause == orderClause || clause == orderedClause) 549 continue; 550 551 pos[clause] = currPos++; 552 553 // For the following clauses, two positions are reserved in the operand 554 // segments 555 if (clause == allocateClause || clause == linearClause) 556 currPos++; 557 } 558 559 SmallVector<int> clauseSegments(currPos); 560 561 // Helper function to check if a clause is allowed/repeated or not 562 auto checkAllowed = [&](ClauseType clause) -> ParseResult { 563 if (!llvm::is_contained(clauses, clause)) 564 return parser.emitError(parser.getCurrentLocation()) 565 << clauseKeyword << " is not a valid clause for the " << opName 566 << " operation"; 567 if (done[clause]) 568 return parser.emitError(parser.getCurrentLocation()) 569 << "at most one " << clauseKeyword << " clause can appear on the " 570 << opName << " operation"; 571 done[clause] = true; 572 return success(); 573 }; 574 575 while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) { 576 if (clauseKeyword == "allocate") { 577 if (checkAllowed(allocateClause) || parser.parseLParen() || 578 parseAllocateAndAllocator(parser, allocates, allocateTypes, 579 allocators, allocatorTypes) || 580 parser.parseRParen()) 581 return failure(); 582 clauseSegments[pos[allocateClause]] = allocates.size(); 583 clauseSegments[pos[allocateClause] + 1] = allocators.size(); 584 } else if (clauseKeyword == "reduction") { 585 if (checkAllowed(reductionClause) || parser.parseLParen() || 586 parseReductionVarList(parser, reductionVars, reductionVarTypes, 587 reductions) || 588 parser.parseRParen()) 589 return failure(); 590 clauseSegments[pos[reductionClause]] = reductionVars.size(); 591 } else if (clauseKeyword == "nowait") { 592 if (checkAllowed(nowaitClause)) 593 return failure(); 594 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 595 result.addAttribute("nowait", attr); 596 } else if (clauseKeyword == "linear") { 597 if (checkAllowed(linearClause) || 598 parseLinearClause(parser, linears, linearTypes, linearSteps)) 599 return failure(); 600 clauseSegments[pos[linearClause]] = linears.size(); 601 clauseSegments[pos[linearClause] + 1] = linearSteps.size(); 602 } else if (clauseKeyword == "schedule") { 603 if (checkAllowed(scheduleClause) || 604 parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize, 605 scheduleChunkType)) 606 return failure(); 607 if (scheduleChunkSize) { 608 clauseSegments[pos[scheduleClause]] = 1; 609 } 610 } else if (clauseKeyword == "collapse") { 611 auto type = parser.getBuilder().getI64Type(); 612 mlir::IntegerAttr attr; 613 if (checkAllowed(collapseClause) || parser.parseLParen() || 614 parser.parseAttribute(attr, type) || parser.parseRParen()) 615 return failure(); 616 result.addAttribute("collapse_val", attr); 617 } else if (clauseKeyword == "ordered") { 618 mlir::IntegerAttr attr; 619 if (checkAllowed(orderedClause)) 620 return failure(); 621 if (succeeded(parser.parseOptionalLParen())) { 622 auto type = parser.getBuilder().getI64Type(); 623 if (parser.parseAttribute(attr, type) || parser.parseRParen()) 624 return failure(); 625 } else { 626 // Use 0 to represent no ordered parameter was specified 627 attr = parser.getBuilder().getI64IntegerAttr(0); 628 } 629 result.addAttribute("ordered_val", attr); 630 } else if (clauseKeyword == "order") { 631 ClauseOrderKindAttr order; 632 if (checkAllowed(orderClause) || parser.parseLParen() || 633 parseClauseAttr<ClauseOrderKindAttr>(parser, order) || 634 parser.parseRParen()) 635 return failure(); 636 result.addAttribute("order_val", order); 637 } else { 638 return parser.emitError(parser.getNameLoc()) 639 << clauseKeyword << " is not a valid clause"; 640 } 641 } 642 643 // Add allocate parameters. 644 if (done[allocateClause] && clauseSegments[pos[allocateClause]] && 645 failed(parser.resolveOperands(allocates, allocateTypes, 646 allocates[0].location, result.operands))) 647 return failure(); 648 649 // Add allocator parameters. 650 if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] && 651 failed(parser.resolveOperands(allocators, allocatorTypes, 652 allocators[0].location, result.operands))) 653 return failure(); 654 655 // Add reduction parameters and symbols 656 if (done[reductionClause] && clauseSegments[pos[reductionClause]]) { 657 if (failed(parser.resolveOperands(reductionVars, reductionVarTypes, 658 parser.getNameLoc(), result.operands))) 659 return failure(); 660 result.addAttribute("reductions", reductions); 661 } 662 663 // Add linear parameters 664 if (done[linearClause] && clauseSegments[pos[linearClause]]) { 665 auto linearStepType = parser.getBuilder().getI32Type(); 666 SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType); 667 if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location, 668 result.operands)) || 669 failed(parser.resolveOperands(linearSteps, linearStepTypes, 670 linearSteps[0].location, 671 result.operands))) 672 return failure(); 673 } 674 675 // Add schedule parameters 676 if (done[scheduleClause] && !schedule.empty()) { 677 schedule[0] = llvm::toUpper(schedule[0]); 678 if (Optional<ClauseScheduleKind> sched = 679 symbolizeClauseScheduleKind(schedule)) { 680 auto attr = ClauseScheduleKindAttr::get(parser.getContext(), *sched); 681 result.addAttribute("schedule_val", attr); 682 } else { 683 return parser.emitError(parser.getCurrentLocation(), 684 "invalid schedule kind"); 685 } 686 if (!modifiers.empty()) { 687 SMLoc loc = parser.getCurrentLocation(); 688 if (Optional<ScheduleModifier> mod = 689 symbolizeScheduleModifier(modifiers[0])) { 690 result.addAttribute( 691 "schedule_modifier", 692 ScheduleModifierAttr::get(parser.getContext(), *mod)); 693 } else { 694 return parser.emitError(loc, "invalid schedule modifier"); 695 } 696 // Only SIMD attribute is allowed here! 697 if (modifiers.size() > 1) { 698 assert(symbolizeScheduleModifier(modifiers[1]) == 699 ScheduleModifier::simd); 700 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 701 result.addAttribute("simd_modifier", attr); 702 } 703 } 704 if (scheduleChunkSize) 705 parser.resolveOperand(*scheduleChunkSize, scheduleChunkType, 706 result.operands); 707 } 708 709 segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end()); 710 711 return success(); 712 } 713 714 //===----------------------------------------------------------------------===// 715 // Verifier for SectionsOp 716 //===----------------------------------------------------------------------===// 717 718 LogicalResult SectionsOp::verify() { 719 if (allocate_vars().size() != allocators_vars().size()) 720 return emitError( 721 "expected equal sizes for allocate and allocator variables"); 722 723 for (auto &inst : *region().begin()) { 724 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) { 725 return emitOpError() 726 << "expected omp.section op or terminator op inside region"; 727 } 728 } 729 730 return verifyReductionVarList(*this, reductions(), reduction_vars()); 731 } 732 733 /// Parses an OpenMP Workshare Loop operation 734 /// 735 /// wsloop ::= `omp.wsloop` loop-control clause-list 736 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds 737 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps 738 /// steps := `step` `(`ssa-id-list`)` 739 /// clause-list ::= clause clause-list | empty 740 /// clause ::= linear | schedule | collapse | nowait | ordered | order 741 /// | reduction 742 ParseResult WsLoopOp::parse(OpAsmParser &parser, OperationState &result) { 743 // Parse an opening `(` followed by induction variables followed by `)` 744 SmallVector<OpAsmParser::OperandType> ivs; 745 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, 746 OpAsmParser::Delimiter::Paren)) 747 return failure(); 748 749 int numIVs = static_cast<int>(ivs.size()); 750 Type loopVarType; 751 if (parser.parseColonType(loopVarType)) 752 return failure(); 753 754 // Parse loop bounds. 755 SmallVector<OpAsmParser::OperandType> lower; 756 if (parser.parseEqual() || 757 parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) || 758 parser.resolveOperands(lower, loopVarType, result.operands)) 759 return failure(); 760 761 SmallVector<OpAsmParser::OperandType> upper; 762 if (parser.parseKeyword("to") || 763 parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) || 764 parser.resolveOperands(upper, loopVarType, result.operands)) 765 return failure(); 766 767 if (succeeded(parser.parseOptionalKeyword("inclusive"))) { 768 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 769 result.addAttribute("inclusive", attr); 770 } 771 772 // Parse step values. 773 SmallVector<OpAsmParser::OperandType> steps; 774 if (parser.parseKeyword("step") || 775 parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) || 776 parser.resolveOperands(steps, loopVarType, result.operands)) 777 return failure(); 778 779 SmallVector<ClauseType> clauses = { 780 linearClause, reductionClause, collapseClause, orderClause, 781 orderedClause, nowaitClause, scheduleClause}; 782 SmallVector<int> segments{numIVs, numIVs, numIVs}; 783 if (failed(parseClauses(parser, result, clauses, segments))) 784 return failure(); 785 786 result.addAttribute("operand_segment_sizes", 787 parser.getBuilder().getI32VectorAttr(segments)); 788 789 // Now parse the body. 790 Region *body = result.addRegion(); 791 SmallVector<Type> ivTypes(numIVs, loopVarType); 792 SmallVector<OpAsmParser::OperandType> blockArgs(ivs); 793 if (parser.parseRegion(*body, blockArgs, ivTypes)) 794 return failure(); 795 return success(); 796 } 797 798 void WsLoopOp::print(OpAsmPrinter &p) { 799 auto args = getRegion().front().getArguments(); 800 p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound() 801 << ") to (" << upperBound() << ") "; 802 if (inclusive()) { 803 p << "inclusive "; 804 } 805 p << "step (" << step() << ") "; 806 807 if (!linear_vars().empty()) 808 printLinearClause(p, linear_vars(), linear_step_vars()); 809 810 if (auto sched = schedule_val()) 811 printScheduleClause(p, sched.getValue(), schedule_modifier(), 812 simd_modifier(), schedule_chunk_var()); 813 814 if (auto collapse = collapse_val()) 815 p << "collapse(" << collapse << ") "; 816 817 if (nowait()) 818 p << "nowait "; 819 820 if (auto ordered = ordered_val()) 821 p << "ordered(" << ordered << ") "; 822 823 if (auto order = order_val()) 824 p << "order(" << stringifyClauseOrderKind(*order) << ") "; 825 826 if (!reduction_vars().empty()) { 827 printReductionVarList(p << "reduction(", *this, reduction_vars(), 828 reduction_vars().getTypes(), reductions()); 829 p << ")"; 830 } 831 832 p << ' '; 833 p.printRegion(region(), /*printEntryBlockArgs=*/false); 834 } 835 836 //===----------------------------------------------------------------------===// 837 // ReductionOp 838 //===----------------------------------------------------------------------===// 839 840 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, 841 Region ®ion) { 842 if (parser.parseOptionalKeyword("atomic")) 843 return success(); 844 return parser.parseRegion(region); 845 } 846 847 static void printAtomicReductionRegion(OpAsmPrinter &printer, 848 ReductionDeclareOp op, Region ®ion) { 849 if (region.empty()) 850 return; 851 printer << "atomic "; 852 printer.printRegion(region); 853 } 854 855 LogicalResult ReductionDeclareOp::verify() { 856 if (initializerRegion().empty()) 857 return emitOpError() << "expects non-empty initializer region"; 858 Block &initializerEntryBlock = initializerRegion().front(); 859 if (initializerEntryBlock.getNumArguments() != 1 || 860 initializerEntryBlock.getArgument(0).getType() != type()) { 861 return emitOpError() << "expects initializer region with one argument " 862 "of the reduction type"; 863 } 864 865 for (YieldOp yieldOp : initializerRegion().getOps<YieldOp>()) { 866 if (yieldOp.results().size() != 1 || 867 yieldOp.results().getTypes()[0] != type()) 868 return emitOpError() << "expects initializer region to yield a value " 869 "of the reduction type"; 870 } 871 872 if (reductionRegion().empty()) 873 return emitOpError() << "expects non-empty reduction region"; 874 Block &reductionEntryBlock = reductionRegion().front(); 875 if (reductionEntryBlock.getNumArguments() != 2 || 876 reductionEntryBlock.getArgumentTypes()[0] != 877 reductionEntryBlock.getArgumentTypes()[1] || 878 reductionEntryBlock.getArgumentTypes()[0] != type()) 879 return emitOpError() << "expects reduction region with two arguments of " 880 "the reduction type"; 881 for (YieldOp yieldOp : reductionRegion().getOps<YieldOp>()) { 882 if (yieldOp.results().size() != 1 || 883 yieldOp.results().getTypes()[0] != type()) 884 return emitOpError() << "expects reduction region to yield a value " 885 "of the reduction type"; 886 } 887 888 if (atomicReductionRegion().empty()) 889 return success(); 890 891 Block &atomicReductionEntryBlock = atomicReductionRegion().front(); 892 if (atomicReductionEntryBlock.getNumArguments() != 2 || 893 atomicReductionEntryBlock.getArgumentTypes()[0] != 894 atomicReductionEntryBlock.getArgumentTypes()[1]) 895 return emitOpError() << "expects atomic reduction region with two " 896 "arguments of the same type"; 897 auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0] 898 .dyn_cast<PointerLikeType>(); 899 if (!ptrType || ptrType.getElementType() != type()) 900 return emitOpError() << "expects atomic reduction region arguments to " 901 "be accumulators containing the reduction type"; 902 return success(); 903 } 904 905 LogicalResult ReductionOp::verify() { 906 // TODO: generalize this to an op interface when there is more than one op 907 // that supports reductions. 908 auto container = (*this)->getParentOfType<WsLoopOp>(); 909 for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) 910 if (container.reduction_vars()[i] == accumulator()) 911 return success(); 912 913 return emitOpError() << "the accumulator is not used by the parent"; 914 } 915 916 //===----------------------------------------------------------------------===// 917 // WsLoopOp 918 //===----------------------------------------------------------------------===// 919 920 void WsLoopOp::build(OpBuilder &builder, OperationState &state, 921 ValueRange lowerBound, ValueRange upperBound, 922 ValueRange step, ArrayRef<NamedAttribute> attributes) { 923 build(builder, state, lowerBound, upperBound, step, 924 /*linear_vars=*/ValueRange(), 925 /*linear_step_vars=*/ValueRange(), /*reduction_vars=*/ValueRange(), 926 /*reductions=*/nullptr, /*schedule_val=*/nullptr, 927 /*schedule_chunk_var=*/nullptr, /*schedule_modifier=*/nullptr, 928 /*simd_modifier=*/false, /*collapse_val=*/nullptr, /*nowait=*/false, 929 /*ordered_val=*/nullptr, /*order_val=*/nullptr, /*inclusive=*/false); 930 state.addAttributes(attributes); 931 } 932 933 LogicalResult WsLoopOp::verify() { 934 return verifyReductionVarList(*this, reductions(), reduction_vars()); 935 } 936 937 //===----------------------------------------------------------------------===// 938 // Verifier for critical construct (2.17.1) 939 //===----------------------------------------------------------------------===// 940 941 LogicalResult CriticalDeclareOp::verify() { 942 return verifySynchronizationHint(*this, hint_val()); 943 } 944 945 LogicalResult CriticalOp::verify() { 946 if (nameAttr()) { 947 SymbolRefAttr symbolRef = nameAttr(); 948 auto decl = SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>( 949 *this, symbolRef); 950 if (!decl) { 951 return emitOpError() << "expected symbol reference " << symbolRef 952 << " to point to a critical declaration"; 953 } 954 } 955 956 return success(); 957 } 958 959 //===----------------------------------------------------------------------===// 960 // Verifier for ordered construct 961 //===----------------------------------------------------------------------===// 962 963 LogicalResult OrderedOp::verify() { 964 auto container = (*this)->getParentOfType<WsLoopOp>(); 965 if (!container || !container.ordered_valAttr() || 966 container.ordered_valAttr().getInt() == 0) 967 return emitOpError() << "ordered depend directive must be closely " 968 << "nested inside a worksharing-loop with ordered " 969 << "clause with parameter present"; 970 971 if (container.ordered_valAttr().getInt() != 972 (int64_t)num_loops_val().getValue()) 973 return emitOpError() << "number of variables in depend clause does not " 974 << "match number of iteration variables in the " 975 << "doacross loop"; 976 977 return success(); 978 } 979 980 LogicalResult OrderedRegionOp::verify() { 981 // TODO: The code generation for ordered simd directive is not supported yet. 982 if (simd()) 983 return failure(); 984 985 if (auto container = (*this)->getParentOfType<WsLoopOp>()) { 986 if (!container.ordered_valAttr() || 987 container.ordered_valAttr().getInt() != 0) 988 return emitOpError() << "ordered region must be closely nested inside " 989 << "a worksharing-loop region with an ordered " 990 << "clause without parameter present"; 991 } 992 993 return success(); 994 } 995 996 //===----------------------------------------------------------------------===// 997 // Verifier for AtomicReadOp 998 //===----------------------------------------------------------------------===// 999 1000 LogicalResult AtomicReadOp::verify() { 1001 if (auto mo = memory_order_val()) { 1002 if (*mo == ClauseMemoryOrderKind::acq_rel || 1003 *mo == ClauseMemoryOrderKind::release) { 1004 return emitError( 1005 "memory-order must not be acq_rel or release for atomic reads"); 1006 } 1007 } 1008 if (x() == v()) 1009 return emitError( 1010 "read and write must not be to the same location for atomic reads"); 1011 return verifySynchronizationHint(*this, hint_val()); 1012 } 1013 1014 //===----------------------------------------------------------------------===// 1015 // Verifier for AtomicWriteOp 1016 //===----------------------------------------------------------------------===// 1017 1018 LogicalResult AtomicWriteOp::verify() { 1019 if (auto mo = memory_order_val()) { 1020 if (*mo == ClauseMemoryOrderKind::acq_rel || 1021 *mo == ClauseMemoryOrderKind::acquire) { 1022 return emitError( 1023 "memory-order must not be acq_rel or acquire for atomic writes"); 1024 } 1025 } 1026 return verifySynchronizationHint(*this, hint_val()); 1027 } 1028 1029 //===----------------------------------------------------------------------===// 1030 // Verifier for AtomicUpdateOp 1031 //===----------------------------------------------------------------------===// 1032 1033 LogicalResult AtomicUpdateOp::verify() { 1034 if (auto mo = memory_order_val()) { 1035 if (*mo == ClauseMemoryOrderKind::acq_rel || 1036 *mo == ClauseMemoryOrderKind::acquire) { 1037 return emitError( 1038 "memory-order must not be acq_rel or acquire for atomic updates"); 1039 } 1040 } 1041 1042 if (region().getNumArguments() != 1) 1043 return emitError("the region must accept exactly one argument"); 1044 1045 if (x().getType().cast<PointerLikeType>().getElementType() != 1046 region().getArgument(0).getType()) { 1047 return emitError("the type of the operand must be a pointer type whose " 1048 "element type is the same as that of the region argument"); 1049 } 1050 1051 YieldOp yieldOp = *region().getOps<YieldOp>().begin(); 1052 1053 if (yieldOp.results().size() != 1) 1054 return emitError("only updated value must be returned"); 1055 if (yieldOp.results().front().getType() != region().getArgument(0).getType()) 1056 return emitError("input and yielded value must have the same type"); 1057 return success(); 1058 } 1059 1060 //===----------------------------------------------------------------------===// 1061 // Verifier for AtomicCaptureOp 1062 //===----------------------------------------------------------------------===// 1063 1064 LogicalResult AtomicCaptureOp::verify() { 1065 Block::OpListType &ops = region().front().getOperations(); 1066 if (ops.size() != 3) 1067 return emitError() 1068 << "expected three operations in omp.atomic.capture region (one " 1069 "terminator, and two atomic ops)"; 1070 auto &firstOp = ops.front(); 1071 auto &secondOp = *ops.getNextNode(firstOp); 1072 auto firstReadStmt = dyn_cast<AtomicReadOp>(firstOp); 1073 auto firstUpdateStmt = dyn_cast<AtomicUpdateOp>(firstOp); 1074 auto secondReadStmt = dyn_cast<AtomicReadOp>(secondOp); 1075 auto secondUpdateStmt = dyn_cast<AtomicUpdateOp>(secondOp); 1076 auto secondWriteStmt = dyn_cast<AtomicWriteOp>(secondOp); 1077 1078 if (!((firstUpdateStmt && secondReadStmt) || 1079 (firstReadStmt && secondUpdateStmt) || 1080 (firstReadStmt && secondWriteStmt))) 1081 return ops.front().emitError() 1082 << "invalid sequence of operations in the capture region"; 1083 if (firstUpdateStmt && secondReadStmt && 1084 firstUpdateStmt.x() != secondReadStmt.x()) 1085 return firstUpdateStmt.emitError() 1086 << "updated variable in omp.atomic.update must be captured in " 1087 "second operation"; 1088 if (firstReadStmt && secondUpdateStmt && 1089 firstReadStmt.x() != secondUpdateStmt.x()) 1090 return firstReadStmt.emitError() 1091 << "captured variable in omp.atomic.read must be updated in second " 1092 "operation"; 1093 if (firstReadStmt && secondWriteStmt && 1094 firstReadStmt.x() != secondWriteStmt.address()) 1095 return firstReadStmt.emitError() 1096 << "captured variable in omp.atomic.read must be updated in " 1097 "second operation"; 1098 return success(); 1099 } 1100 1101 #define GET_ATTRDEF_CLASSES 1102 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" 1103 1104 #define GET_OP_CLASSES 1105 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 1106