1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the OpenMP dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/DialectImplementation.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/IR/OperationSupport.h" 19 20 #include "llvm/ADT/BitVector.h" 21 #include "llvm/ADT/SmallString.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/ADT/StringRef.h" 24 #include "llvm/ADT/StringSwitch.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 #include <cstddef> 27 28 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc" 29 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" 30 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc" 31 32 using namespace mlir; 33 using namespace mlir::omp; 34 35 namespace { 36 /// Model for pointer-like types that already provide a `getElementType` method. 37 template <typename T> 38 struct PointerLikeModel 39 : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> { 40 Type getElementType(Type pointer) const { 41 return pointer.cast<T>().getElementType(); 42 } 43 }; 44 } // namespace 45 46 void OpenMPDialect::initialize() { 47 addOperations< 48 #define GET_OP_LIST 49 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 50 >(); 51 addAttributes< 52 #define GET_ATTRDEF_LIST 53 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" 54 >(); 55 56 LLVM::LLVMPointerType::attachInterface< 57 PointerLikeModel<LLVM::LLVMPointerType>>(*getContext()); 58 MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext()); 59 } 60 61 //===----------------------------------------------------------------------===// 62 // ParallelOp 63 //===----------------------------------------------------------------------===// 64 65 void ParallelOp::build(OpBuilder &builder, OperationState &state, 66 ArrayRef<NamedAttribute> attributes) { 67 ParallelOp::build( 68 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, 69 /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(), 70 /*proc_bind_val=*/nullptr); 71 state.addAttributes(attributes); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // Parser and printer for Allocate Clause 76 //===----------------------------------------------------------------------===// 77 78 /// Parse an allocate clause with allocators and a list of operands with types. 79 /// 80 /// allocate-operand-list :: = allocate-operand | 81 /// allocator-operand `,` allocate-operand-list 82 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 83 /// ssa-id-and-type ::= ssa-id `:` type 84 static ParseResult parseAllocateAndAllocator( 85 OpAsmParser &parser, 86 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate, 87 SmallVectorImpl<Type> &typesAllocate, 88 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator, 89 SmallVectorImpl<Type> &typesAllocator) { 90 91 return parser.parseCommaSeparatedList([&]() -> ParseResult { 92 OpAsmParser::OperandType operand; 93 Type type; 94 if (parser.parseOperand(operand) || parser.parseColonType(type)) 95 return failure(); 96 operandsAllocator.push_back(operand); 97 typesAllocator.push_back(type); 98 if (parser.parseArrow()) 99 return failure(); 100 if (parser.parseOperand(operand) || parser.parseColonType(type)) 101 return failure(); 102 103 operandsAllocate.push_back(operand); 104 typesAllocate.push_back(type); 105 return success(); 106 }); 107 } 108 109 /// Print allocate clause 110 static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, 111 OperandRange varsAllocate, 112 TypeRange typesAllocate, 113 OperandRange varsAllocator, 114 TypeRange typesAllocator) { 115 for (unsigned i = 0; i < varsAllocate.size(); ++i) { 116 std::string separator = i == varsAllocate.size() - 1 ? "" : ", "; 117 p << varsAllocator[i] << " : " << typesAllocator[i] << " -> "; 118 p << varsAllocate[i] << " : " << typesAllocate[i] << separator; 119 } 120 } 121 122 /// Parse a clause attribute (StringEnumAttr) 123 template <typename ClauseAttr> 124 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) { 125 using ClauseT = decltype(std::declval<ClauseAttr>().getValue()); 126 StringRef enumStr; 127 SMLoc loc = parser.getCurrentLocation(); 128 if (parser.parseKeyword(&enumStr)) 129 return failure(); 130 if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) { 131 attr = ClauseAttr::get(parser.getContext(), *enumValue); 132 return success(); 133 } 134 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'"; 135 } 136 137 template <typename ClauseAttr> 138 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) { 139 p << stringifyEnum(attr.getValue()); 140 } 141 142 //===----------------------------------------------------------------------===// 143 // Parser and printer for Procbind Clause 144 //===----------------------------------------------------------------------===// 145 146 ParseResult parseProcBindKind(OpAsmParser &parser, 147 omp::ClauseProcBindKindAttr &procBindAttr) { 148 StringRef procBindStr; 149 if (parser.parseKeyword(&procBindStr)) 150 return failure(); 151 if (auto procBindVal = symbolizeClauseProcBindKind(procBindStr)) { 152 procBindAttr = 153 ClauseProcBindKindAttr::get(parser.getContext(), *procBindVal); 154 return success(); 155 } 156 return failure(); 157 } 158 159 void printProcBindKind(OpAsmPrinter &p, Operation *op, 160 omp::ClauseProcBindKindAttr procBindAttr) { 161 p << stringifyClauseProcBindKind(procBindAttr.getValue()); 162 } 163 164 LogicalResult ParallelOp::verify() { 165 if (allocate_vars().size() != allocators_vars().size()) 166 return emitError( 167 "expected equal sizes for allocate and allocator variables"); 168 return success(); 169 } 170 171 //===----------------------------------------------------------------------===// 172 // Parser and printer for Linear Clause 173 //===----------------------------------------------------------------------===// 174 175 /// linear ::= `linear` `(` linear-list `)` 176 /// linear-list := linear-val | linear-val linear-list 177 /// linear-val := ssa-id-and-type `=` ssa-id-and-type 178 static ParseResult 179 parseLinearClause(OpAsmParser &parser, 180 SmallVectorImpl<OpAsmParser::OperandType> &vars, 181 SmallVectorImpl<Type> &types, 182 SmallVectorImpl<OpAsmParser::OperandType> &stepVars) { 183 if (parser.parseLParen()) 184 return failure(); 185 186 do { 187 OpAsmParser::OperandType var; 188 Type type; 189 OpAsmParser::OperandType stepVar; 190 if (parser.parseOperand(var) || parser.parseEqual() || 191 parser.parseOperand(stepVar) || parser.parseColonType(type)) 192 return failure(); 193 194 vars.push_back(var); 195 types.push_back(type); 196 stepVars.push_back(stepVar); 197 } while (succeeded(parser.parseOptionalComma())); 198 199 if (parser.parseRParen()) 200 return failure(); 201 202 return success(); 203 } 204 205 /// Print Linear Clause 206 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars, 207 OperandRange linearStepVars) { 208 size_t linearVarsSize = linearVars.size(); 209 p << "linear("; 210 for (unsigned i = 0; i < linearVarsSize; ++i) { 211 std::string separator = i == linearVarsSize - 1 ? ") " : ", "; 212 p << linearVars[i]; 213 if (linearStepVars.size() > i) 214 p << " = " << linearStepVars[i]; 215 p << " : " << linearVars[i].getType() << separator; 216 } 217 } 218 219 //===----------------------------------------------------------------------===// 220 // Parser, printer and verifier for Schedule Clause 221 //===----------------------------------------------------------------------===// 222 223 static ParseResult 224 verifyScheduleModifiers(OpAsmParser &parser, 225 SmallVectorImpl<SmallString<12>> &modifiers) { 226 if (modifiers.size() > 2) 227 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)"; 228 for (const auto &mod : modifiers) { 229 // Translate the string. If it has no value, then it was not a valid 230 // modifier! 231 auto symbol = symbolizeScheduleModifier(mod); 232 if (!symbol.hasValue()) 233 return parser.emitError(parser.getNameLoc()) 234 << " unknown modifier type: " << mod; 235 } 236 237 // If we have one modifier that is "simd", then stick a "none" modiifer in 238 // index 0. 239 if (modifiers.size() == 1) { 240 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) { 241 modifiers.push_back(modifiers[0]); 242 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none); 243 } 244 } else if (modifiers.size() == 2) { 245 // If there are two modifier: 246 // First modifier should not be simd, second one should be simd 247 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd || 248 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd) 249 return parser.emitError(parser.getNameLoc()) 250 << " incorrect modifier order"; 251 } 252 return success(); 253 } 254 255 /// schedule ::= `schedule` `(` sched-list `)` 256 /// sched-list ::= sched-val | sched-val sched-list | 257 /// sched-val `,` sched-modifier 258 /// sched-val ::= sched-with-chunk | sched-wo-chunk 259 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 260 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided` 261 /// sched-wo-chunk ::= `auto` | `runtime` 262 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val 263 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none` 264 static ParseResult 265 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, 266 SmallVectorImpl<SmallString<12>> &modifiers, 267 Optional<OpAsmParser::OperandType> &chunkSize, 268 Type &chunkType) { 269 if (parser.parseLParen()) 270 return failure(); 271 272 StringRef keyword; 273 if (parser.parseKeyword(&keyword)) 274 return failure(); 275 276 schedule = keyword; 277 if (keyword == "static" || keyword == "dynamic" || keyword == "guided") { 278 if (succeeded(parser.parseOptionalEqual())) { 279 chunkSize = OpAsmParser::OperandType{}; 280 if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType)) 281 return failure(); 282 } else { 283 chunkSize = llvm::NoneType::None; 284 } 285 } else if (keyword == "auto" || keyword == "runtime") { 286 chunkSize = llvm::NoneType::None; 287 } else { 288 return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; 289 } 290 291 // If there is a comma, we have one or more modifiers.. 292 while (succeeded(parser.parseOptionalComma())) { 293 StringRef mod; 294 if (parser.parseKeyword(&mod)) 295 return failure(); 296 modifiers.push_back(mod); 297 } 298 299 if (parser.parseRParen()) 300 return failure(); 301 302 if (verifyScheduleModifiers(parser, modifiers)) 303 return failure(); 304 305 return success(); 306 } 307 308 /// Print schedule clause 309 static void printScheduleClause(OpAsmPrinter &p, ClauseScheduleKind sched, 310 Optional<ScheduleModifier> modifier, bool simd, 311 Value scheduleChunkVar) { 312 p << "schedule(" << stringifyClauseScheduleKind(sched).lower(); 313 if (scheduleChunkVar) 314 p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType(); 315 if (modifier) 316 p << ", " << stringifyScheduleModifier(*modifier); 317 if (simd) 318 p << ", simd"; 319 p << ") "; 320 } 321 322 //===----------------------------------------------------------------------===// 323 // Parser, printer and verifier for ReductionVarList 324 //===----------------------------------------------------------------------===// 325 326 /// reduction-entry-list ::= reduction-entry 327 /// | reduction-entry-list `,` reduction-entry 328 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type 329 static ParseResult parseReductionVarList( 330 OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &operands, 331 SmallVectorImpl<Type> &types, ArrayAttr &redcuctionSymbols) { 332 SmallVector<SymbolRefAttr> reductionVec; 333 do { 334 if (parser.parseAttribute(reductionVec.emplace_back()) || 335 parser.parseArrow() || parser.parseOperand(operands.emplace_back()) || 336 parser.parseColonType(types.emplace_back())) 337 return failure(); 338 } while (succeeded(parser.parseOptionalComma())); 339 SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end()); 340 redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions); 341 return success(); 342 } 343 344 /// Print Reduction clause 345 static void printReductionVarList(OpAsmPrinter &p, Operation *op, 346 OperandRange reductionVars, 347 TypeRange reductionTypes, 348 Optional<ArrayAttr> reductions) { 349 for (unsigned i = 0, e = reductions->size(); i < e; ++i) { 350 if (i != 0) 351 p << ", "; 352 p << (*reductions)[i] << " -> " << reductionVars[i] << " : " 353 << reductionVars[i].getType(); 354 } 355 } 356 357 /// Verifies Reduction Clause 358 static LogicalResult verifyReductionVarList(Operation *op, 359 Optional<ArrayAttr> reductions, 360 OperandRange reductionVars) { 361 if (!reductionVars.empty()) { 362 if (!reductions || reductions->size() != reductionVars.size()) 363 return op->emitOpError() 364 << "expected as many reduction symbol references " 365 "as reduction variables"; 366 } else { 367 if (reductions) 368 return op->emitOpError() << "unexpected reduction symbol references"; 369 return success(); 370 } 371 372 // TODO: The followings should be done in 373 // SymbolUserOpInterface::verifySymbolUses. 374 DenseSet<Value> accumulators; 375 for (auto args : llvm::zip(reductionVars, *reductions)) { 376 Value accum = std::get<0>(args); 377 378 if (!accumulators.insert(accum).second) 379 return op->emitOpError() << "accumulator variable used more than once"; 380 381 Type varType = accum.getType().cast<PointerLikeType>(); 382 auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>(); 383 auto decl = 384 SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef); 385 if (!decl) 386 return op->emitOpError() << "expected symbol reference " << symbolRef 387 << " to point to a reduction declaration"; 388 389 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) 390 return op->emitOpError() 391 << "expected accumulator (" << varType 392 << ") to be the same type as reduction declaration (" 393 << decl.getAccumulatorType() << ")"; 394 } 395 396 return success(); 397 } 398 399 //===----------------------------------------------------------------------===// 400 // Parser, printer and verifier for Synchronization Hint (2.17.12) 401 //===----------------------------------------------------------------------===// 402 403 /// Parses a Synchronization Hint clause. The value of hint is an integer 404 /// which is a combination of different hints from `omp_sync_hint_t`. 405 /// 406 /// hint-clause = `hint` `(` hint-value `)` 407 static ParseResult parseSynchronizationHint(OpAsmParser &parser, 408 IntegerAttr &hintAttr) { 409 StringRef hintKeyword; 410 int64_t hint = 0; 411 do { 412 if (failed(parser.parseKeyword(&hintKeyword))) 413 return failure(); 414 if (hintKeyword == "uncontended") 415 hint |= 1; 416 else if (hintKeyword == "contended") 417 hint |= 2; 418 else if (hintKeyword == "nonspeculative") 419 hint |= 4; 420 else if (hintKeyword == "speculative") 421 hint |= 8; 422 else 423 return parser.emitError(parser.getCurrentLocation()) 424 << hintKeyword << " is not a valid hint"; 425 } while (succeeded(parser.parseOptionalComma())); 426 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint); 427 return success(); 428 } 429 430 /// Prints a Synchronization Hint clause 431 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, 432 IntegerAttr hintAttr) { 433 int64_t hint = hintAttr.getInt(); 434 435 if (hint == 0) 436 return; 437 438 // Helper function to get n-th bit from the right end of `value` 439 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 440 441 bool uncontended = bitn(hint, 0); 442 bool contended = bitn(hint, 1); 443 bool nonspeculative = bitn(hint, 2); 444 bool speculative = bitn(hint, 3); 445 446 SmallVector<StringRef> hints; 447 if (uncontended) 448 hints.push_back("uncontended"); 449 if (contended) 450 hints.push_back("contended"); 451 if (nonspeculative) 452 hints.push_back("nonspeculative"); 453 if (speculative) 454 hints.push_back("speculative"); 455 456 llvm::interleaveComma(hints, p); 457 } 458 459 /// Verifies a synchronization hint clause 460 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) { 461 462 // Helper function to get n-th bit from the right end of `value` 463 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 464 465 bool uncontended = bitn(hint, 0); 466 bool contended = bitn(hint, 1); 467 bool nonspeculative = bitn(hint, 2); 468 bool speculative = bitn(hint, 3); 469 470 if (uncontended && contended) 471 return op->emitOpError() << "the hints omp_sync_hint_uncontended and " 472 "omp_sync_hint_contended cannot be combined"; 473 if (nonspeculative && speculative) 474 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and " 475 "omp_sync_hint_speculative cannot be combined."; 476 return success(); 477 } 478 479 enum ClauseType { 480 allocateClause, 481 reductionClause, 482 nowaitClause, 483 linearClause, 484 scheduleClause, 485 collapseClause, 486 orderClause, 487 orderedClause, 488 COUNT 489 }; 490 491 //===----------------------------------------------------------------------===// 492 // Parser for Clause List 493 //===----------------------------------------------------------------------===// 494 495 /// Parse a list of clauses. The clauses can appear in any order, but their 496 /// operand segment indices are in the same order that they are passed in the 497 /// `clauses` list. The operand segments are added over the prevSegments 498 499 /// clause-list ::= clause clause-list | empty 500 /// clause ::= allocate | reduction | nowait | linear | schedule | collapse 501 /// | order | ordered 502 /// allocate ::= `allocate` `(` allocate-operand-list `)` 503 /// reduction ::= `reduction` `(` reduction-entry-list `)` 504 /// nowait ::= `nowait` 505 /// linear ::= `linear` `(` linear-list `)` 506 /// schedule ::= `schedule` `(` sched-list `)` 507 /// collapse ::= `collapse` `(` ssa-id-and-type `)` 508 /// order ::= `order` `(` `concurrent` `)` 509 /// ordered ::= `ordered` `(` ssa-id-and-type `)` 510 /// 511 /// Note that each clause can only appear once in the clase-list. 512 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result, 513 SmallVectorImpl<ClauseType> &clauses, 514 SmallVectorImpl<int> &segments) { 515 516 // Check done[clause] to see if it has been parsed already 517 BitVector done(ClauseType::COUNT, false); 518 519 // See pos[clause] to get position of clause in operand segments 520 SmallVector<int> pos(ClauseType::COUNT, -1); 521 522 // Stores the last parsed clause keyword 523 StringRef clauseKeyword; 524 StringRef opName = result.name.getStringRef(); 525 526 // Containers for storing operands, types and attributes for various clauses 527 SmallVector<OpAsmParser::OperandType> allocates, allocators; 528 SmallVector<Type> allocateTypes, allocatorTypes; 529 530 ArrayAttr reductions; 531 SmallVector<OpAsmParser::OperandType> reductionVars; 532 SmallVector<Type> reductionVarTypes; 533 534 SmallVector<OpAsmParser::OperandType> linears; 535 SmallVector<Type> linearTypes; 536 SmallVector<OpAsmParser::OperandType> linearSteps; 537 538 SmallString<8> schedule; 539 SmallVector<SmallString<12>> modifiers; 540 Optional<OpAsmParser::OperandType> scheduleChunkSize; 541 Type scheduleChunkType; 542 543 // Compute the position of clauses in operand segments 544 int currPos = 0; 545 for (ClauseType clause : clauses) { 546 547 // Skip the following clauses - they do not take any position in operand 548 // segments 549 if (clause == nowaitClause || clause == collapseClause || 550 clause == orderClause || clause == orderedClause) 551 continue; 552 553 pos[clause] = currPos++; 554 555 // For the following clauses, two positions are reserved in the operand 556 // segments 557 if (clause == allocateClause || clause == linearClause) 558 currPos++; 559 } 560 561 SmallVector<int> clauseSegments(currPos); 562 563 // Helper function to check if a clause is allowed/repeated or not 564 auto checkAllowed = [&](ClauseType clause) -> ParseResult { 565 if (!llvm::is_contained(clauses, clause)) 566 return parser.emitError(parser.getCurrentLocation()) 567 << clauseKeyword << " is not a valid clause for the " << opName 568 << " operation"; 569 if (done[clause]) 570 return parser.emitError(parser.getCurrentLocation()) 571 << "at most one " << clauseKeyword << " clause can appear on the " 572 << opName << " operation"; 573 done[clause] = true; 574 return success(); 575 }; 576 577 while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) { 578 if (clauseKeyword == "allocate") { 579 if (checkAllowed(allocateClause) || parser.parseLParen() || 580 parseAllocateAndAllocator(parser, allocates, allocateTypes, 581 allocators, allocatorTypes) || 582 parser.parseRParen()) 583 return failure(); 584 clauseSegments[pos[allocateClause]] = allocates.size(); 585 clauseSegments[pos[allocateClause] + 1] = allocators.size(); 586 } else if (clauseKeyword == "reduction") { 587 if (checkAllowed(reductionClause) || parser.parseLParen() || 588 parseReductionVarList(parser, reductionVars, reductionVarTypes, 589 reductions) || 590 parser.parseRParen()) 591 return failure(); 592 clauseSegments[pos[reductionClause]] = reductionVars.size(); 593 } else if (clauseKeyword == "nowait") { 594 if (checkAllowed(nowaitClause)) 595 return failure(); 596 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 597 result.addAttribute("nowait", attr); 598 } else if (clauseKeyword == "linear") { 599 if (checkAllowed(linearClause) || 600 parseLinearClause(parser, linears, linearTypes, linearSteps)) 601 return failure(); 602 clauseSegments[pos[linearClause]] = linears.size(); 603 clauseSegments[pos[linearClause] + 1] = linearSteps.size(); 604 } else if (clauseKeyword == "schedule") { 605 if (checkAllowed(scheduleClause) || 606 parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize, 607 scheduleChunkType)) 608 return failure(); 609 if (scheduleChunkSize) { 610 clauseSegments[pos[scheduleClause]] = 1; 611 } 612 } else if (clauseKeyword == "collapse") { 613 auto type = parser.getBuilder().getI64Type(); 614 mlir::IntegerAttr attr; 615 if (checkAllowed(collapseClause) || parser.parseLParen() || 616 parser.parseAttribute(attr, type) || parser.parseRParen()) 617 return failure(); 618 result.addAttribute("collapse_val", attr); 619 } else if (clauseKeyword == "ordered") { 620 mlir::IntegerAttr attr; 621 if (checkAllowed(orderedClause)) 622 return failure(); 623 if (succeeded(parser.parseOptionalLParen())) { 624 auto type = parser.getBuilder().getI64Type(); 625 if (parser.parseAttribute(attr, type) || parser.parseRParen()) 626 return failure(); 627 } else { 628 // Use 0 to represent no ordered parameter was specified 629 attr = parser.getBuilder().getI64IntegerAttr(0); 630 } 631 result.addAttribute("ordered_val", attr); 632 } else if (clauseKeyword == "order") { 633 ClauseOrderKindAttr order; 634 if (checkAllowed(orderClause) || parser.parseLParen() || 635 parseClauseAttr<ClauseOrderKindAttr>(parser, order) || 636 parser.parseRParen()) 637 return failure(); 638 result.addAttribute("order_val", order); 639 } else { 640 return parser.emitError(parser.getNameLoc()) 641 << clauseKeyword << " is not a valid clause"; 642 } 643 } 644 645 // Add allocate parameters. 646 if (done[allocateClause] && clauseSegments[pos[allocateClause]] && 647 failed(parser.resolveOperands(allocates, allocateTypes, 648 allocates[0].location, result.operands))) 649 return failure(); 650 651 // Add allocator parameters. 652 if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] && 653 failed(parser.resolveOperands(allocators, allocatorTypes, 654 allocators[0].location, result.operands))) 655 return failure(); 656 657 // Add reduction parameters and symbols 658 if (done[reductionClause] && clauseSegments[pos[reductionClause]]) { 659 if (failed(parser.resolveOperands(reductionVars, reductionVarTypes, 660 parser.getNameLoc(), result.operands))) 661 return failure(); 662 result.addAttribute("reductions", reductions); 663 } 664 665 // Add linear parameters 666 if (done[linearClause] && clauseSegments[pos[linearClause]]) { 667 auto linearStepType = parser.getBuilder().getI32Type(); 668 SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType); 669 if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location, 670 result.operands)) || 671 failed(parser.resolveOperands(linearSteps, linearStepTypes, 672 linearSteps[0].location, 673 result.operands))) 674 return failure(); 675 } 676 677 // Add schedule parameters 678 if (done[scheduleClause] && !schedule.empty()) { 679 if (Optional<ClauseScheduleKind> sched = 680 symbolizeClauseScheduleKind(schedule)) { 681 auto attr = ClauseScheduleKindAttr::get(parser.getContext(), *sched); 682 result.addAttribute("schedule_val", attr); 683 } else { 684 return parser.emitError(parser.getCurrentLocation(), 685 "invalid schedule kind"); 686 } 687 if (!modifiers.empty()) { 688 SMLoc loc = parser.getCurrentLocation(); 689 if (Optional<ScheduleModifier> mod = 690 symbolizeScheduleModifier(modifiers[0])) { 691 result.addAttribute( 692 "schedule_modifier", 693 ScheduleModifierAttr::get(parser.getContext(), *mod)); 694 } else { 695 return parser.emitError(loc, "invalid schedule modifier"); 696 } 697 // Only SIMD attribute is allowed here! 698 if (modifiers.size() > 1) { 699 assert(symbolizeScheduleModifier(modifiers[1]) == 700 ScheduleModifier::simd); 701 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 702 result.addAttribute("simd_modifier", attr); 703 } 704 } 705 if (scheduleChunkSize) 706 parser.resolveOperand(*scheduleChunkSize, scheduleChunkType, 707 result.operands); 708 } 709 710 segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end()); 711 712 return success(); 713 } 714 715 //===----------------------------------------------------------------------===// 716 // Verifier for SectionsOp 717 //===----------------------------------------------------------------------===// 718 719 LogicalResult SectionsOp::verify() { 720 if (allocate_vars().size() != allocators_vars().size()) 721 return emitError( 722 "expected equal sizes for allocate and allocator variables"); 723 724 return verifyReductionVarList(*this, reductions(), reduction_vars()); 725 } 726 727 LogicalResult SectionsOp::verifyRegions() { 728 for (auto &inst : *region().begin()) { 729 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) { 730 return emitOpError() 731 << "expected omp.section op or terminator op inside region"; 732 } 733 } 734 735 return success(); 736 } 737 738 /// Parses an OpenMP Workshare Loop operation 739 /// 740 /// wsloop ::= `omp.wsloop` loop-control clause-list 741 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds 742 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps 743 /// steps := `step` `(`ssa-id-list`)` 744 /// clause-list ::= clause clause-list | empty 745 /// clause ::= linear | schedule | collapse | nowait | ordered | order 746 /// | reduction 747 ParseResult WsLoopOp::parse(OpAsmParser &parser, OperationState &result) { 748 // Parse an opening `(` followed by induction variables followed by `)` 749 SmallVector<OpAsmParser::OperandType> ivs; 750 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, 751 OpAsmParser::Delimiter::Paren)) 752 return failure(); 753 754 int numIVs = static_cast<int>(ivs.size()); 755 Type loopVarType; 756 if (parser.parseColonType(loopVarType)) 757 return failure(); 758 759 // Parse loop bounds. 760 SmallVector<OpAsmParser::OperandType> lower; 761 if (parser.parseEqual() || 762 parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) || 763 parser.resolveOperands(lower, loopVarType, result.operands)) 764 return failure(); 765 766 SmallVector<OpAsmParser::OperandType> upper; 767 if (parser.parseKeyword("to") || 768 parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) || 769 parser.resolveOperands(upper, loopVarType, result.operands)) 770 return failure(); 771 772 if (succeeded(parser.parseOptionalKeyword("inclusive"))) { 773 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 774 result.addAttribute("inclusive", attr); 775 } 776 777 // Parse step values. 778 SmallVector<OpAsmParser::OperandType> steps; 779 if (parser.parseKeyword("step") || 780 parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) || 781 parser.resolveOperands(steps, loopVarType, result.operands)) 782 return failure(); 783 784 SmallVector<ClauseType> clauses = { 785 linearClause, reductionClause, collapseClause, orderClause, 786 orderedClause, nowaitClause, scheduleClause}; 787 SmallVector<int> segments{numIVs, numIVs, numIVs}; 788 if (failed(parseClauses(parser, result, clauses, segments))) 789 return failure(); 790 791 result.addAttribute("operand_segment_sizes", 792 parser.getBuilder().getI32VectorAttr(segments)); 793 794 // Now parse the body. 795 Region *body = result.addRegion(); 796 SmallVector<Type> ivTypes(numIVs, loopVarType); 797 SmallVector<OpAsmParser::OperandType> blockArgs(ivs); 798 if (parser.parseRegion(*body, blockArgs, ivTypes)) 799 return failure(); 800 return success(); 801 } 802 803 void WsLoopOp::print(OpAsmPrinter &p) { 804 auto args = getRegion().front().getArguments(); 805 p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound() 806 << ") to (" << upperBound() << ") "; 807 if (inclusive()) { 808 p << "inclusive "; 809 } 810 p << "step (" << step() << ") "; 811 812 if (!linear_vars().empty()) 813 printLinearClause(p, linear_vars(), linear_step_vars()); 814 815 if (auto sched = schedule_val()) 816 printScheduleClause(p, sched.getValue(), schedule_modifier(), 817 simd_modifier(), schedule_chunk_var()); 818 819 if (auto collapse = collapse_val()) 820 p << "collapse(" << collapse << ") "; 821 822 if (nowait()) 823 p << "nowait "; 824 825 if (auto ordered = ordered_val()) 826 p << "ordered(" << ordered << ") "; 827 828 if (auto order = order_val()) 829 p << "order(" << stringifyClauseOrderKind(*order) << ") "; 830 831 if (!reduction_vars().empty()) { 832 printReductionVarList(p << "reduction(", *this, reduction_vars(), 833 reduction_vars().getTypes(), reductions()); 834 p << ")"; 835 } 836 837 p << ' '; 838 p.printRegion(region(), /*printEntryBlockArgs=*/false); 839 } 840 841 //===----------------------------------------------------------------------===// 842 // SimdLoopOp 843 //===----------------------------------------------------------------------===// 844 /// Parses an OpenMP Simd construct [2.9.3.1] 845 /// 846 /// simdloop ::= `omp.simdloop` loop-control clause-list 847 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds 848 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps 849 /// steps := `step` `(`ssa-id-list`)` 850 /// clause-list ::= clause clause-list | empty 851 /// clause ::= TODO 852 ParseResult SimdLoopOp::parse(OpAsmParser &parser, OperationState &result) { 853 // Parse an opening `(` followed by induction variables followed by `)` 854 SmallVector<OpAsmParser::OperandType> ivs; 855 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, 856 OpAsmParser::Delimiter::Paren)) 857 return failure(); 858 int numIVs = static_cast<int>(ivs.size()); 859 Type loopVarType; 860 if (parser.parseColonType(loopVarType)) 861 return failure(); 862 // Parse loop bounds. 863 SmallVector<OpAsmParser::OperandType> lower; 864 if (parser.parseEqual() || 865 parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) || 866 parser.resolveOperands(lower, loopVarType, result.operands)) 867 return failure(); 868 SmallVector<OpAsmParser::OperandType> upper; 869 if (parser.parseKeyword("to") || 870 parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) || 871 parser.resolveOperands(upper, loopVarType, result.operands)) 872 return failure(); 873 874 // Parse step values. 875 SmallVector<OpAsmParser::OperandType> steps; 876 if (parser.parseKeyword("step") || 877 parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) || 878 parser.resolveOperands(steps, loopVarType, result.operands)) 879 return failure(); 880 881 SmallVector<int> segments{numIVs, numIVs, numIVs}; 882 // TODO: Add parseClauses() when we support clauses 883 result.addAttribute("operand_segment_sizes", 884 parser.getBuilder().getI32VectorAttr(segments)); 885 886 // Now parse the body. 887 Region *body = result.addRegion(); 888 SmallVector<Type> ivTypes(numIVs, loopVarType); 889 SmallVector<OpAsmParser::OperandType> blockArgs(ivs); 890 if (parser.parseRegion(*body, blockArgs, ivTypes)) 891 return failure(); 892 return success(); 893 } 894 895 void SimdLoopOp::print(OpAsmPrinter &p) { 896 auto args = getRegion().front().getArguments(); 897 p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound() 898 << ") to (" << upperBound() << ") "; 899 p << "step (" << step() << ") "; 900 901 p.printRegion(region(), /*printEntryBlockArgs=*/false); 902 } 903 904 //===----------------------------------------------------------------------===// 905 // Verifier for Simd construct [2.9.3.1] 906 //===----------------------------------------------------------------------===// 907 908 LogicalResult SimdLoopOp::verify() { 909 if (this->lowerBound().empty()) { 910 return emitOpError() << "empty lowerbound for simd loop operation"; 911 } 912 return success(); 913 } 914 915 //===----------------------------------------------------------------------===// 916 // ReductionOp 917 //===----------------------------------------------------------------------===// 918 919 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, 920 Region ®ion) { 921 if (parser.parseOptionalKeyword("atomic")) 922 return success(); 923 return parser.parseRegion(region); 924 } 925 926 static void printAtomicReductionRegion(OpAsmPrinter &printer, 927 ReductionDeclareOp op, Region ®ion) { 928 if (region.empty()) 929 return; 930 printer << "atomic "; 931 printer.printRegion(region); 932 } 933 934 LogicalResult ReductionDeclareOp::verifyRegions() { 935 if (initializerRegion().empty()) 936 return emitOpError() << "expects non-empty initializer region"; 937 Block &initializerEntryBlock = initializerRegion().front(); 938 if (initializerEntryBlock.getNumArguments() != 1 || 939 initializerEntryBlock.getArgument(0).getType() != type()) { 940 return emitOpError() << "expects initializer region with one argument " 941 "of the reduction type"; 942 } 943 944 for (YieldOp yieldOp : initializerRegion().getOps<YieldOp>()) { 945 if (yieldOp.results().size() != 1 || 946 yieldOp.results().getTypes()[0] != type()) 947 return emitOpError() << "expects initializer region to yield a value " 948 "of the reduction type"; 949 } 950 951 if (reductionRegion().empty()) 952 return emitOpError() << "expects non-empty reduction region"; 953 Block &reductionEntryBlock = reductionRegion().front(); 954 if (reductionEntryBlock.getNumArguments() != 2 || 955 reductionEntryBlock.getArgumentTypes()[0] != 956 reductionEntryBlock.getArgumentTypes()[1] || 957 reductionEntryBlock.getArgumentTypes()[0] != type()) 958 return emitOpError() << "expects reduction region with two arguments of " 959 "the reduction type"; 960 for (YieldOp yieldOp : reductionRegion().getOps<YieldOp>()) { 961 if (yieldOp.results().size() != 1 || 962 yieldOp.results().getTypes()[0] != type()) 963 return emitOpError() << "expects reduction region to yield a value " 964 "of the reduction type"; 965 } 966 967 if (atomicReductionRegion().empty()) 968 return success(); 969 970 Block &atomicReductionEntryBlock = atomicReductionRegion().front(); 971 if (atomicReductionEntryBlock.getNumArguments() != 2 || 972 atomicReductionEntryBlock.getArgumentTypes()[0] != 973 atomicReductionEntryBlock.getArgumentTypes()[1]) 974 return emitOpError() << "expects atomic reduction region with two " 975 "arguments of the same type"; 976 auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0] 977 .dyn_cast<PointerLikeType>(); 978 if (!ptrType || ptrType.getElementType() != type()) 979 return emitOpError() << "expects atomic reduction region arguments to " 980 "be accumulators containing the reduction type"; 981 return success(); 982 } 983 984 LogicalResult ReductionOp::verify() { 985 // TODO: generalize this to an op interface when there is more than one op 986 // that supports reductions. 987 auto container = (*this)->getParentOfType<WsLoopOp>(); 988 for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) 989 if (container.reduction_vars()[i] == accumulator()) 990 return success(); 991 992 return emitOpError() << "the accumulator is not used by the parent"; 993 } 994 995 //===----------------------------------------------------------------------===// 996 // WsLoopOp 997 //===----------------------------------------------------------------------===// 998 999 void WsLoopOp::build(OpBuilder &builder, OperationState &state, 1000 ValueRange lowerBound, ValueRange upperBound, 1001 ValueRange step, ArrayRef<NamedAttribute> attributes) { 1002 build(builder, state, lowerBound, upperBound, step, 1003 /*linear_vars=*/ValueRange(), 1004 /*linear_step_vars=*/ValueRange(), /*reduction_vars=*/ValueRange(), 1005 /*reductions=*/nullptr, /*schedule_val=*/nullptr, 1006 /*schedule_chunk_var=*/nullptr, /*schedule_modifier=*/nullptr, 1007 /*simd_modifier=*/false, /*collapse_val=*/nullptr, /*nowait=*/false, 1008 /*ordered_val=*/nullptr, /*order_val=*/nullptr, /*inclusive=*/false); 1009 state.addAttributes(attributes); 1010 } 1011 1012 LogicalResult WsLoopOp::verify() { 1013 return verifyReductionVarList(*this, reductions(), reduction_vars()); 1014 } 1015 1016 //===----------------------------------------------------------------------===// 1017 // Verifier for critical construct (2.17.1) 1018 //===----------------------------------------------------------------------===// 1019 1020 LogicalResult CriticalDeclareOp::verify() { 1021 return verifySynchronizationHint(*this, hint_val()); 1022 } 1023 1024 LogicalResult 1025 CriticalOp::verifySymbolUses(SymbolTableCollection &symbol_table) { 1026 if (nameAttr()) { 1027 SymbolRefAttr symbolRef = nameAttr(); 1028 auto decl = symbol_table.lookupNearestSymbolFrom<CriticalDeclareOp>( 1029 *this, symbolRef); 1030 if (!decl) { 1031 return emitOpError() << "expected symbol reference " << symbolRef 1032 << " to point to a critical declaration"; 1033 } 1034 } 1035 1036 return success(); 1037 } 1038 1039 //===----------------------------------------------------------------------===// 1040 // Verifier for ordered construct 1041 //===----------------------------------------------------------------------===// 1042 1043 LogicalResult OrderedOp::verify() { 1044 auto container = (*this)->getParentOfType<WsLoopOp>(); 1045 if (!container || !container.ordered_valAttr() || 1046 container.ordered_valAttr().getInt() == 0) 1047 return emitOpError() << "ordered depend directive must be closely " 1048 << "nested inside a worksharing-loop with ordered " 1049 << "clause with parameter present"; 1050 1051 if (container.ordered_valAttr().getInt() != 1052 (int64_t)num_loops_val().getValue()) 1053 return emitOpError() << "number of variables in depend clause does not " 1054 << "match number of iteration variables in the " 1055 << "doacross loop"; 1056 1057 return success(); 1058 } 1059 1060 LogicalResult OrderedRegionOp::verify() { 1061 // TODO: The code generation for ordered simd directive is not supported yet. 1062 if (simd()) 1063 return failure(); 1064 1065 if (auto container = (*this)->getParentOfType<WsLoopOp>()) { 1066 if (!container.ordered_valAttr() || 1067 container.ordered_valAttr().getInt() != 0) 1068 return emitOpError() << "ordered region must be closely nested inside " 1069 << "a worksharing-loop region with an ordered " 1070 << "clause without parameter present"; 1071 } 1072 1073 return success(); 1074 } 1075 1076 //===----------------------------------------------------------------------===// 1077 // Verifier for AtomicReadOp 1078 //===----------------------------------------------------------------------===// 1079 1080 LogicalResult AtomicReadOp::verify() { 1081 if (auto mo = memory_order_val()) { 1082 if (*mo == ClauseMemoryOrderKind::Acq_rel || 1083 *mo == ClauseMemoryOrderKind::Release) { 1084 return emitError( 1085 "memory-order must not be acq_rel or release for atomic reads"); 1086 } 1087 } 1088 if (x() == v()) 1089 return emitError( 1090 "read and write must not be to the same location for atomic reads"); 1091 return verifySynchronizationHint(*this, hint_val()); 1092 } 1093 1094 //===----------------------------------------------------------------------===// 1095 // Verifier for AtomicWriteOp 1096 //===----------------------------------------------------------------------===// 1097 1098 LogicalResult AtomicWriteOp::verify() { 1099 if (auto mo = memory_order_val()) { 1100 if (*mo == ClauseMemoryOrderKind::Acq_rel || 1101 *mo == ClauseMemoryOrderKind::Acquire) { 1102 return emitError( 1103 "memory-order must not be acq_rel or acquire for atomic writes"); 1104 } 1105 } 1106 return verifySynchronizationHint(*this, hint_val()); 1107 } 1108 1109 //===----------------------------------------------------------------------===// 1110 // Verifier for AtomicUpdateOp 1111 //===----------------------------------------------------------------------===// 1112 1113 LogicalResult AtomicUpdateOp::verify() { 1114 if (auto mo = memory_order_val()) { 1115 if (*mo == ClauseMemoryOrderKind::Acq_rel || 1116 *mo == ClauseMemoryOrderKind::Acquire) { 1117 return emitError( 1118 "memory-order must not be acq_rel or acquire for atomic updates"); 1119 } 1120 } 1121 1122 if (x().getType().cast<PointerLikeType>().getElementType() != 1123 region().getArgument(0).getType()) { 1124 return emitError("the type of the operand must be a pointer type whose " 1125 "element type is the same as that of the region argument"); 1126 } 1127 1128 return success(); 1129 } 1130 1131 LogicalResult AtomicUpdateOp::verifyRegions() { 1132 if (region().getNumArguments() != 1) 1133 return emitError("the region must accept exactly one argument"); 1134 1135 if (region().front().getOperations().size() < 2) 1136 return emitError() << "the update region must have at least two operations " 1137 "(binop and terminator)"; 1138 1139 YieldOp yieldOp = *region().getOps<YieldOp>().begin(); 1140 1141 if (yieldOp.results().size() != 1) 1142 return emitError("only updated value must be returned"); 1143 if (yieldOp.results().front().getType() != region().getArgument(0).getType()) 1144 return emitError("input and yielded value must have the same type"); 1145 return success(); 1146 } 1147 1148 //===----------------------------------------------------------------------===// 1149 // Verifier for AtomicCaptureOp 1150 //===----------------------------------------------------------------------===// 1151 1152 LogicalResult AtomicCaptureOp::verifyRegions() { 1153 Block::OpListType &ops = region().front().getOperations(); 1154 if (ops.size() != 3) 1155 return emitError() 1156 << "expected three operations in omp.atomic.capture region (one " 1157 "terminator, and two atomic ops)"; 1158 auto &firstOp = ops.front(); 1159 auto &secondOp = *ops.getNextNode(firstOp); 1160 auto firstReadStmt = dyn_cast<AtomicReadOp>(firstOp); 1161 auto firstUpdateStmt = dyn_cast<AtomicUpdateOp>(firstOp); 1162 auto secondReadStmt = dyn_cast<AtomicReadOp>(secondOp); 1163 auto secondUpdateStmt = dyn_cast<AtomicUpdateOp>(secondOp); 1164 auto secondWriteStmt = dyn_cast<AtomicWriteOp>(secondOp); 1165 1166 if (!((firstUpdateStmt && secondReadStmt) || 1167 (firstReadStmt && secondUpdateStmt) || 1168 (firstReadStmt && secondWriteStmt))) 1169 return ops.front().emitError() 1170 << "invalid sequence of operations in the capture region"; 1171 if (firstUpdateStmt && secondReadStmt && 1172 firstUpdateStmt.x() != secondReadStmt.x()) 1173 return firstUpdateStmt.emitError() 1174 << "updated variable in omp.atomic.update must be captured in " 1175 "second operation"; 1176 if (firstReadStmt && secondUpdateStmt && 1177 firstReadStmt.x() != secondUpdateStmt.x()) 1178 return firstReadStmt.emitError() 1179 << "captured variable in omp.atomic.read must be updated in second " 1180 "operation"; 1181 if (firstReadStmt && secondWriteStmt && 1182 firstReadStmt.x() != secondWriteStmt.address()) 1183 return firstReadStmt.emitError() 1184 << "captured variable in omp.atomic.read must be updated in " 1185 "second operation"; 1186 return success(); 1187 } 1188 1189 #define GET_ATTRDEF_CLASSES 1190 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" 1191 1192 #define GET_OP_CLASSES 1193 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 1194