1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the OpenMP dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/DialectImplementation.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/IR/OperationSupport.h" 19 20 #include "llvm/ADT/BitVector.h" 21 #include "llvm/ADT/SmallString.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/ADT/StringRef.h" 24 #include "llvm/ADT/StringSwitch.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 #include <cstddef> 27 28 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc" 29 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" 30 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc" 31 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc" 32 33 using namespace mlir; 34 using namespace mlir::omp; 35 36 namespace { 37 /// Model for pointer-like types that already provide a `getElementType` method. 38 template <typename T> 39 struct PointerLikeModel 40 : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> { 41 Type getElementType(Type pointer) const { 42 return pointer.cast<T>().getElementType(); 43 } 44 }; 45 } // namespace 46 47 void OpenMPDialect::initialize() { 48 addOperations< 49 #define GET_OP_LIST 50 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 51 >(); 52 addAttributes< 53 #define GET_ATTRDEF_LIST 54 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" 55 >(); 56 57 LLVM::LLVMPointerType::attachInterface< 58 PointerLikeModel<LLVM::LLVMPointerType>>(*getContext()); 59 MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext()); 60 } 61 62 //===----------------------------------------------------------------------===// 63 // Parser and printer for Allocate Clause 64 //===----------------------------------------------------------------------===// 65 66 /// Parse an allocate clause with allocators and a list of operands with types. 67 /// 68 /// allocate-operand-list :: = allocate-operand | 69 /// allocator-operand `,` allocate-operand-list 70 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 71 /// ssa-id-and-type ::= ssa-id `:` type 72 static ParseResult parseAllocateAndAllocator( 73 OpAsmParser &parser, 74 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operandsAllocate, 75 SmallVectorImpl<Type> &typesAllocate, 76 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operandsAllocator, 77 SmallVectorImpl<Type> &typesAllocator) { 78 79 return parser.parseCommaSeparatedList([&]() { 80 OpAsmParser::UnresolvedOperand operand; 81 Type type; 82 if (parser.parseOperand(operand) || parser.parseColonType(type)) 83 return failure(); 84 operandsAllocator.push_back(operand); 85 typesAllocator.push_back(type); 86 if (parser.parseArrow()) 87 return failure(); 88 if (parser.parseOperand(operand) || parser.parseColonType(type)) 89 return failure(); 90 91 operandsAllocate.push_back(operand); 92 typesAllocate.push_back(type); 93 return success(); 94 }); 95 } 96 97 /// Print allocate clause 98 static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, 99 OperandRange varsAllocate, 100 TypeRange typesAllocate, 101 OperandRange varsAllocator, 102 TypeRange typesAllocator) { 103 for (unsigned i = 0; i < varsAllocate.size(); ++i) { 104 std::string separator = i == varsAllocate.size() - 1 ? "" : ", "; 105 p << varsAllocator[i] << " : " << typesAllocator[i] << " -> "; 106 p << varsAllocate[i] << " : " << typesAllocate[i] << separator; 107 } 108 } 109 110 //===----------------------------------------------------------------------===// 111 // Parser and printer for a clause attribute (StringEnumAttr) 112 //===----------------------------------------------------------------------===// 113 114 template <typename ClauseAttr> 115 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) { 116 using ClauseT = decltype(std::declval<ClauseAttr>().getValue()); 117 StringRef enumStr; 118 SMLoc loc = parser.getCurrentLocation(); 119 if (parser.parseKeyword(&enumStr)) 120 return failure(); 121 if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) { 122 attr = ClauseAttr::get(parser.getContext(), *enumValue); 123 return success(); 124 } 125 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'"; 126 } 127 128 template <typename ClauseAttr> 129 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) { 130 p << stringifyEnum(attr.getValue()); 131 } 132 133 //===----------------------------------------------------------------------===// 134 // Parser and printer for Linear Clause 135 //===----------------------------------------------------------------------===// 136 137 /// linear ::= `linear` `(` linear-list `)` 138 /// linear-list := linear-val | linear-val linear-list 139 /// linear-val := ssa-id-and-type `=` ssa-id-and-type 140 static ParseResult 141 parseLinearClause(OpAsmParser &parser, 142 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars, 143 SmallVectorImpl<Type> &types, 144 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &stepVars) { 145 return parser.parseCommaSeparatedList([&]() { 146 OpAsmParser::UnresolvedOperand var; 147 Type type; 148 OpAsmParser::UnresolvedOperand stepVar; 149 if (parser.parseOperand(var) || parser.parseEqual() || 150 parser.parseOperand(stepVar) || parser.parseColonType(type)) 151 return failure(); 152 153 vars.push_back(var); 154 types.push_back(type); 155 stepVars.push_back(stepVar); 156 return success(); 157 }); 158 } 159 160 /// Print Linear Clause 161 static void printLinearClause(OpAsmPrinter &p, Operation *op, 162 ValueRange linearVars, TypeRange linearVarTypes, 163 ValueRange linearStepVars) { 164 size_t linearVarsSize = linearVars.size(); 165 for (unsigned i = 0; i < linearVarsSize; ++i) { 166 std::string separator = i == linearVarsSize - 1 ? "" : ", "; 167 p << linearVars[i]; 168 if (linearStepVars.size() > i) 169 p << " = " << linearStepVars[i]; 170 p << " : " << linearVars[i].getType() << separator; 171 } 172 } 173 174 //===----------------------------------------------------------------------===// 175 // Parser, printer and verifier for Schedule Clause 176 //===----------------------------------------------------------------------===// 177 178 static ParseResult 179 verifyScheduleModifiers(OpAsmParser &parser, 180 SmallVectorImpl<SmallString<12>> &modifiers) { 181 if (modifiers.size() > 2) 182 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)"; 183 for (const auto &mod : modifiers) { 184 // Translate the string. If it has no value, then it was not a valid 185 // modifier! 186 auto symbol = symbolizeScheduleModifier(mod); 187 if (!symbol) 188 return parser.emitError(parser.getNameLoc()) 189 << " unknown modifier type: " << mod; 190 } 191 192 // If we have one modifier that is "simd", then stick a "none" modiifer in 193 // index 0. 194 if (modifiers.size() == 1) { 195 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) { 196 modifiers.push_back(modifiers[0]); 197 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none); 198 } 199 } else if (modifiers.size() == 2) { 200 // If there are two modifier: 201 // First modifier should not be simd, second one should be simd 202 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd || 203 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd) 204 return parser.emitError(parser.getNameLoc()) 205 << " incorrect modifier order"; 206 } 207 return success(); 208 } 209 210 /// schedule ::= `schedule` `(` sched-list `)` 211 /// sched-list ::= sched-val | sched-val sched-list | 212 /// sched-val `,` sched-modifier 213 /// sched-val ::= sched-with-chunk | sched-wo-chunk 214 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 215 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided` 216 /// sched-wo-chunk ::= `auto` | `runtime` 217 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val 218 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none` 219 static ParseResult parseScheduleClause( 220 OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, 221 ScheduleModifierAttr &scheduleModifier, UnitAttr &simdModifier, 222 Optional<OpAsmParser::UnresolvedOperand> &chunkSize, Type &chunkType) { 223 StringRef keyword; 224 if (parser.parseKeyword(&keyword)) 225 return failure(); 226 llvm::Optional<mlir::omp::ClauseScheduleKind> schedule = 227 symbolizeClauseScheduleKind(keyword); 228 if (!schedule) 229 return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; 230 231 scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule); 232 switch (*schedule) { 233 case ClauseScheduleKind::Static: 234 case ClauseScheduleKind::Dynamic: 235 case ClauseScheduleKind::Guided: 236 if (succeeded(parser.parseOptionalEqual())) { 237 chunkSize = OpAsmParser::UnresolvedOperand{}; 238 if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType)) 239 return failure(); 240 } else { 241 chunkSize = llvm::NoneType::None; 242 } 243 break; 244 case ClauseScheduleKind::Auto: 245 case ClauseScheduleKind::Runtime: 246 chunkSize = llvm::NoneType::None; 247 } 248 249 // If there is a comma, we have one or more modifiers.. 250 SmallVector<SmallString<12>> modifiers; 251 while (succeeded(parser.parseOptionalComma())) { 252 StringRef mod; 253 if (parser.parseKeyword(&mod)) 254 return failure(); 255 modifiers.push_back(mod); 256 } 257 258 if (verifyScheduleModifiers(parser, modifiers)) 259 return failure(); 260 261 if (!modifiers.empty()) { 262 SMLoc loc = parser.getCurrentLocation(); 263 if (Optional<ScheduleModifier> mod = 264 symbolizeScheduleModifier(modifiers[0])) { 265 scheduleModifier = ScheduleModifierAttr::get(parser.getContext(), *mod); 266 } else { 267 return parser.emitError(loc, "invalid schedule modifier"); 268 } 269 // Only SIMD attribute is allowed here! 270 if (modifiers.size() > 1) { 271 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd); 272 simdModifier = UnitAttr::get(parser.getBuilder().getContext()); 273 } 274 } 275 276 return success(); 277 } 278 279 /// Print schedule clause 280 static void printScheduleClause(OpAsmPrinter &p, Operation *op, 281 ClauseScheduleKindAttr schedAttr, 282 ScheduleModifierAttr modifier, UnitAttr simd, 283 Value scheduleChunkVar, 284 Type scheduleChunkType) { 285 p << stringifyClauseScheduleKind(schedAttr.getValue()); 286 if (scheduleChunkVar) 287 p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType(); 288 if (modifier) 289 p << ", " << stringifyScheduleModifier(modifier.getValue()); 290 if (simd) 291 p << ", simd"; 292 } 293 294 //===----------------------------------------------------------------------===// 295 // Parser, printer and verifier for ReductionVarList 296 //===----------------------------------------------------------------------===// 297 298 /// reduction-entry-list ::= reduction-entry 299 /// | reduction-entry-list `,` reduction-entry 300 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type 301 static ParseResult 302 parseReductionVarList(OpAsmParser &parser, 303 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands, 304 SmallVectorImpl<Type> &types, 305 ArrayAttr &redcuctionSymbols) { 306 SmallVector<SymbolRefAttr> reductionVec; 307 if (failed(parser.parseCommaSeparatedList([&]() { 308 if (parser.parseAttribute(reductionVec.emplace_back()) || 309 parser.parseArrow() || 310 parser.parseOperand(operands.emplace_back()) || 311 parser.parseColonType(types.emplace_back())) 312 return failure(); 313 return success(); 314 }))) 315 return failure(); 316 SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end()); 317 redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions); 318 return success(); 319 } 320 321 /// Print Reduction clause 322 static void printReductionVarList(OpAsmPrinter &p, Operation *op, 323 OperandRange reductionVars, 324 TypeRange reductionTypes, 325 Optional<ArrayAttr> reductions) { 326 for (unsigned i = 0, e = reductions->size(); i < e; ++i) { 327 if (i != 0) 328 p << ", "; 329 p << (*reductions)[i] << " -> " << reductionVars[i] << " : " 330 << reductionVars[i].getType(); 331 } 332 } 333 334 /// Verifies Reduction Clause 335 static LogicalResult verifyReductionVarList(Operation *op, 336 Optional<ArrayAttr> reductions, 337 OperandRange reductionVars) { 338 if (!reductionVars.empty()) { 339 if (!reductions || reductions->size() != reductionVars.size()) 340 return op->emitOpError() 341 << "expected as many reduction symbol references " 342 "as reduction variables"; 343 } else { 344 if (reductions) 345 return op->emitOpError() << "unexpected reduction symbol references"; 346 return success(); 347 } 348 349 // TODO: The followings should be done in 350 // SymbolUserOpInterface::verifySymbolUses. 351 DenseSet<Value> accumulators; 352 for (auto args : llvm::zip(reductionVars, *reductions)) { 353 Value accum = std::get<0>(args); 354 355 if (!accumulators.insert(accum).second) 356 return op->emitOpError() << "accumulator variable used more than once"; 357 358 Type varType = accum.getType().cast<PointerLikeType>(); 359 auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>(); 360 auto decl = 361 SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef); 362 if (!decl) 363 return op->emitOpError() << "expected symbol reference " << symbolRef 364 << " to point to a reduction declaration"; 365 366 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) 367 return op->emitOpError() 368 << "expected accumulator (" << varType 369 << ") to be the same type as reduction declaration (" 370 << decl.getAccumulatorType() << ")"; 371 } 372 373 return success(); 374 } 375 376 //===----------------------------------------------------------------------===// 377 // Parser, printer and verifier for Synchronization Hint (2.17.12) 378 //===----------------------------------------------------------------------===// 379 380 /// Parses a Synchronization Hint clause. The value of hint is an integer 381 /// which is a combination of different hints from `omp_sync_hint_t`. 382 /// 383 /// hint-clause = `hint` `(` hint-value `)` 384 static ParseResult parseSynchronizationHint(OpAsmParser &parser, 385 IntegerAttr &hintAttr) { 386 StringRef hintKeyword; 387 int64_t hint = 0; 388 if (succeeded(parser.parseOptionalKeyword("none"))) { 389 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0); 390 return success(); 391 } 392 auto parseKeyword = [&]() -> ParseResult { 393 if (failed(parser.parseKeyword(&hintKeyword))) 394 return failure(); 395 if (hintKeyword == "uncontended") 396 hint |= 1; 397 else if (hintKeyword == "contended") 398 hint |= 2; 399 else if (hintKeyword == "nonspeculative") 400 hint |= 4; 401 else if (hintKeyword == "speculative") 402 hint |= 8; 403 else 404 return parser.emitError(parser.getCurrentLocation()) 405 << hintKeyword << " is not a valid hint"; 406 return success(); 407 }; 408 if (parser.parseCommaSeparatedList(parseKeyword)) 409 return failure(); 410 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint); 411 return success(); 412 } 413 414 /// Prints a Synchronization Hint clause 415 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, 416 IntegerAttr hintAttr) { 417 int64_t hint = hintAttr.getInt(); 418 419 if (hint == 0) { 420 p << "none"; 421 return; 422 } 423 424 // Helper function to get n-th bit from the right end of `value` 425 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 426 427 bool uncontended = bitn(hint, 0); 428 bool contended = bitn(hint, 1); 429 bool nonspeculative = bitn(hint, 2); 430 bool speculative = bitn(hint, 3); 431 432 SmallVector<StringRef> hints; 433 if (uncontended) 434 hints.push_back("uncontended"); 435 if (contended) 436 hints.push_back("contended"); 437 if (nonspeculative) 438 hints.push_back("nonspeculative"); 439 if (speculative) 440 hints.push_back("speculative"); 441 442 llvm::interleaveComma(hints, p); 443 } 444 445 /// Verifies a synchronization hint clause 446 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) { 447 448 // Helper function to get n-th bit from the right end of `value` 449 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 450 451 bool uncontended = bitn(hint, 0); 452 bool contended = bitn(hint, 1); 453 bool nonspeculative = bitn(hint, 2); 454 bool speculative = bitn(hint, 3); 455 456 if (uncontended && contended) 457 return op->emitOpError() << "the hints omp_sync_hint_uncontended and " 458 "omp_sync_hint_contended cannot be combined"; 459 if (nonspeculative && speculative) 460 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and " 461 "omp_sync_hint_speculative cannot be combined."; 462 return success(); 463 } 464 465 //===----------------------------------------------------------------------===// 466 // ParallelOp 467 //===----------------------------------------------------------------------===// 468 469 void ParallelOp::build(OpBuilder &builder, OperationState &state, 470 ArrayRef<NamedAttribute> attributes) { 471 ParallelOp::build( 472 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, 473 /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(), 474 /*reduction_vars=*/ValueRange(), /*reductions=*/nullptr, 475 /*proc_bind_val=*/nullptr); 476 state.addAttributes(attributes); 477 } 478 479 LogicalResult ParallelOp::verify() { 480 if (allocate_vars().size() != allocators_vars().size()) 481 return emitError( 482 "expected equal sizes for allocate and allocator variables"); 483 return verifyReductionVarList(*this, reductions(), reduction_vars()); 484 } 485 486 //===----------------------------------------------------------------------===// 487 // Verifier for SectionsOp 488 //===----------------------------------------------------------------------===// 489 490 LogicalResult SectionsOp::verify() { 491 if (allocate_vars().size() != allocators_vars().size()) 492 return emitError( 493 "expected equal sizes for allocate and allocator variables"); 494 495 return verifyReductionVarList(*this, reductions(), reduction_vars()); 496 } 497 498 LogicalResult SectionsOp::verifyRegions() { 499 for (auto &inst : *region().begin()) { 500 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) { 501 return emitOpError() 502 << "expected omp.section op or terminator op inside region"; 503 } 504 } 505 506 return success(); 507 } 508 509 LogicalResult SingleOp::verify() { 510 // Check for allocate clause restrictions 511 if (allocate_vars().size() != allocators_vars().size()) 512 return emitError( 513 "expected equal sizes for allocate and allocator variables"); 514 515 return success(); 516 } 517 518 //===----------------------------------------------------------------------===// 519 // WsLoopOp 520 //===----------------------------------------------------------------------===// 521 522 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds 523 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps 524 /// steps := `step` `(`ssa-id-list`)` 525 ParseResult 526 parseWsLoopControl(OpAsmParser &parser, Region ®ion, 527 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &lowerBound, 528 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &upperBound, 529 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &steps, 530 SmallVectorImpl<Type> &loopVarTypes, UnitAttr &inclusive) { 531 // Parse an opening `(` followed by induction variables followed by `)` 532 SmallVector<OpAsmParser::Argument> ivs; 533 Type loopVarType; 534 if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) || 535 parser.parseColonType(loopVarType) || 536 // Parse loop bounds. 537 parser.parseEqual() || 538 parser.parseOperandList(lowerBound, ivs.size(), 539 OpAsmParser::Delimiter::Paren) || 540 parser.parseKeyword("to") || 541 parser.parseOperandList(upperBound, ivs.size(), 542 OpAsmParser::Delimiter::Paren)) 543 return failure(); 544 545 if (succeeded(parser.parseOptionalKeyword("inclusive"))) 546 inclusive = UnitAttr::get(parser.getBuilder().getContext()); 547 548 // Parse step values. 549 if (parser.parseKeyword("step") || 550 parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren)) 551 return failure(); 552 553 // Now parse the body. 554 loopVarTypes = SmallVector<Type>(ivs.size(), loopVarType); 555 for (auto &iv : ivs) 556 iv.type = loopVarType; 557 return parser.parseRegion(region, ivs); 558 } 559 560 void printWsLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, 561 ValueRange lowerBound, ValueRange upperBound, 562 ValueRange steps, TypeRange loopVarTypes, 563 UnitAttr inclusive) { 564 auto args = region.front().getArguments(); 565 p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound 566 << ") to (" << upperBound << ") "; 567 if (inclusive) 568 p << "inclusive "; 569 p << "step (" << steps << ") "; 570 p.printRegion(region, /*printEntryBlockArgs=*/false); 571 } 572 573 //===----------------------------------------------------------------------===// 574 // SimdLoopOp 575 //===----------------------------------------------------------------------===// 576 /// Parses an OpenMP Simd construct [2.9.3.1] 577 /// 578 /// simdloop ::= `omp.simdloop` loop-control clause-list 579 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds 580 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps 581 /// steps := `step` `(`ssa-id-list`)` 582 /// clause-list ::= clause clause-list | empty 583 /// clause ::= TODO 584 ParseResult SimdLoopOp::parse(OpAsmParser &parser, OperationState &result) { 585 // Parse an opening `(` followed by induction variables followed by `)` 586 SmallVector<OpAsmParser::Argument> ivs; 587 Type loopVarType; 588 SmallVector<OpAsmParser::UnresolvedOperand> lower, upper, steps; 589 if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) || 590 parser.parseColonType(loopVarType) || 591 // Parse loop bounds. 592 parser.parseEqual() || 593 parser.parseOperandList(lower, ivs.size(), 594 OpAsmParser::Delimiter::Paren) || 595 parser.resolveOperands(lower, loopVarType, result.operands) || 596 parser.parseKeyword("to") || 597 parser.parseOperandList(upper, ivs.size(), 598 OpAsmParser::Delimiter::Paren) || 599 parser.resolveOperands(upper, loopVarType, result.operands) || 600 // Parse step values. 601 parser.parseKeyword("step") || 602 parser.parseOperandList(steps, ivs.size(), 603 OpAsmParser::Delimiter::Paren) || 604 parser.resolveOperands(steps, loopVarType, result.operands)) 605 return failure(); 606 607 int numIVs = static_cast<int>(ivs.size()); 608 SmallVector<int> segments{numIVs, numIVs, numIVs}; 609 // TODO: Add parseClauses() when we support clauses 610 result.addAttribute("operand_segment_sizes", 611 parser.getBuilder().getI32VectorAttr(segments)); 612 613 // Now parse the body. 614 Region *body = result.addRegion(); 615 for (auto &iv : ivs) 616 iv.type = loopVarType; 617 return parser.parseRegion(*body, ivs); 618 } 619 620 void SimdLoopOp::print(OpAsmPrinter &p) { 621 auto args = getRegion().front().getArguments(); 622 p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound() 623 << ") to (" << upperBound() << ") "; 624 p << "step (" << step() << ") "; 625 626 p.printRegion(region(), /*printEntryBlockArgs=*/false); 627 } 628 629 //===----------------------------------------------------------------------===// 630 // Verifier for Simd construct [2.9.3.1] 631 //===----------------------------------------------------------------------===// 632 633 LogicalResult SimdLoopOp::verify() { 634 if (this->lowerBound().empty()) { 635 return emitOpError() << "empty lowerbound for simd loop operation"; 636 } 637 return success(); 638 } 639 640 //===----------------------------------------------------------------------===// 641 // ReductionOp 642 //===----------------------------------------------------------------------===// 643 644 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, 645 Region ®ion) { 646 if (parser.parseOptionalKeyword("atomic")) 647 return success(); 648 return parser.parseRegion(region); 649 } 650 651 static void printAtomicReductionRegion(OpAsmPrinter &printer, 652 ReductionDeclareOp op, Region ®ion) { 653 if (region.empty()) 654 return; 655 printer << "atomic "; 656 printer.printRegion(region); 657 } 658 659 LogicalResult ReductionDeclareOp::verifyRegions() { 660 if (initializerRegion().empty()) 661 return emitOpError() << "expects non-empty initializer region"; 662 Block &initializerEntryBlock = initializerRegion().front(); 663 if (initializerEntryBlock.getNumArguments() != 1 || 664 initializerEntryBlock.getArgument(0).getType() != type()) { 665 return emitOpError() << "expects initializer region with one argument " 666 "of the reduction type"; 667 } 668 669 for (YieldOp yieldOp : initializerRegion().getOps<YieldOp>()) { 670 if (yieldOp.results().size() != 1 || 671 yieldOp.results().getTypes()[0] != type()) 672 return emitOpError() << "expects initializer region to yield a value " 673 "of the reduction type"; 674 } 675 676 if (reductionRegion().empty()) 677 return emitOpError() << "expects non-empty reduction region"; 678 Block &reductionEntryBlock = reductionRegion().front(); 679 if (reductionEntryBlock.getNumArguments() != 2 || 680 reductionEntryBlock.getArgumentTypes()[0] != 681 reductionEntryBlock.getArgumentTypes()[1] || 682 reductionEntryBlock.getArgumentTypes()[0] != type()) 683 return emitOpError() << "expects reduction region with two arguments of " 684 "the reduction type"; 685 for (YieldOp yieldOp : reductionRegion().getOps<YieldOp>()) { 686 if (yieldOp.results().size() != 1 || 687 yieldOp.results().getTypes()[0] != type()) 688 return emitOpError() << "expects reduction region to yield a value " 689 "of the reduction type"; 690 } 691 692 if (atomicReductionRegion().empty()) 693 return success(); 694 695 Block &atomicReductionEntryBlock = atomicReductionRegion().front(); 696 if (atomicReductionEntryBlock.getNumArguments() != 2 || 697 atomicReductionEntryBlock.getArgumentTypes()[0] != 698 atomicReductionEntryBlock.getArgumentTypes()[1]) 699 return emitOpError() << "expects atomic reduction region with two " 700 "arguments of the same type"; 701 auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0] 702 .dyn_cast<PointerLikeType>(); 703 if (!ptrType || ptrType.getElementType() != type()) 704 return emitOpError() << "expects atomic reduction region arguments to " 705 "be accumulators containing the reduction type"; 706 return success(); 707 } 708 709 LogicalResult ReductionOp::verify() { 710 auto *op = (*this)->getParentWithTrait<ReductionClauseInterface::Trait>(); 711 if (!op) 712 return emitOpError() << "must be used within an operation supporting " 713 "reduction clause interface"; 714 while (op) { 715 for (const auto &var : 716 cast<ReductionClauseInterface>(op).getReductionVars()) 717 if (var == accumulator()) 718 return success(); 719 op = op->getParentWithTrait<ReductionClauseInterface::Trait>(); 720 } 721 return emitOpError() << "the accumulator is not used by the parent"; 722 } 723 724 //===----------------------------------------------------------------------===// 725 // TaskOp 726 //===----------------------------------------------------------------------===// 727 LogicalResult TaskOp::verify() { 728 return verifyReductionVarList(*this, in_reductions(), in_reduction_vars()); 729 } 730 731 //===----------------------------------------------------------------------===// 732 // TaskGroupOp 733 //===----------------------------------------------------------------------===// 734 LogicalResult TaskGroupOp::verify() { 735 return verifyReductionVarList(*this, task_reductions(), 736 task_reduction_vars()); 737 } 738 739 //===----------------------------------------------------------------------===// 740 // WsLoopOp 741 //===----------------------------------------------------------------------===// 742 743 void WsLoopOp::build(OpBuilder &builder, OperationState &state, 744 ValueRange lowerBound, ValueRange upperBound, 745 ValueRange step, ArrayRef<NamedAttribute> attributes) { 746 build(builder, state, lowerBound, upperBound, step, 747 /*linear_vars=*/ValueRange(), 748 /*linear_step_vars=*/ValueRange(), /*reduction_vars=*/ValueRange(), 749 /*reductions=*/nullptr, /*schedule_val=*/nullptr, 750 /*schedule_chunk_var=*/nullptr, /*schedule_modifier=*/nullptr, 751 /*simd_modifier=*/false, /*collapse_val=*/nullptr, /*nowait=*/false, 752 /*ordered_val=*/nullptr, /*order_val=*/nullptr, /*inclusive=*/false); 753 state.addAttributes(attributes); 754 } 755 756 LogicalResult WsLoopOp::verify() { 757 return verifyReductionVarList(*this, reductions(), reduction_vars()); 758 } 759 760 //===----------------------------------------------------------------------===// 761 // Verifier for critical construct (2.17.1) 762 //===----------------------------------------------------------------------===// 763 764 LogicalResult CriticalDeclareOp::verify() { 765 return verifySynchronizationHint(*this, hint_val()); 766 } 767 768 LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 769 if (nameAttr()) { 770 SymbolRefAttr symbolRef = nameAttr(); 771 auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>( 772 *this, symbolRef); 773 if (!decl) { 774 return emitOpError() << "expected symbol reference " << symbolRef 775 << " to point to a critical declaration"; 776 } 777 } 778 779 return success(); 780 } 781 782 //===----------------------------------------------------------------------===// 783 // Verifier for ordered construct 784 //===----------------------------------------------------------------------===// 785 786 LogicalResult OrderedOp::verify() { 787 auto container = (*this)->getParentOfType<WsLoopOp>(); 788 if (!container || !container.ordered_valAttr() || 789 container.ordered_valAttr().getInt() == 0) 790 return emitOpError() << "ordered depend directive must be closely " 791 << "nested inside a worksharing-loop with ordered " 792 << "clause with parameter present"; 793 794 if (container.ordered_valAttr().getInt() != (int64_t)*num_loops_val()) 795 return emitOpError() << "number of variables in depend clause does not " 796 << "match number of iteration variables in the " 797 << "doacross loop"; 798 799 return success(); 800 } 801 802 LogicalResult OrderedRegionOp::verify() { 803 // TODO: The code generation for ordered simd directive is not supported yet. 804 if (simd()) 805 return failure(); 806 807 if (auto container = (*this)->getParentOfType<WsLoopOp>()) { 808 if (!container.ordered_valAttr() || 809 container.ordered_valAttr().getInt() != 0) 810 return emitOpError() << "ordered region must be closely nested inside " 811 << "a worksharing-loop region with an ordered " 812 << "clause without parameter present"; 813 } 814 815 return success(); 816 } 817 818 //===----------------------------------------------------------------------===// 819 // Verifier for AtomicReadOp 820 //===----------------------------------------------------------------------===// 821 822 LogicalResult AtomicReadOp::verify() { 823 if (auto mo = memory_order_val()) { 824 if (*mo == ClauseMemoryOrderKind::Acq_rel || 825 *mo == ClauseMemoryOrderKind::Release) { 826 return emitError( 827 "memory-order must not be acq_rel or release for atomic reads"); 828 } 829 } 830 if (x() == v()) 831 return emitError( 832 "read and write must not be to the same location for atomic reads"); 833 return verifySynchronizationHint(*this, hint_val()); 834 } 835 836 //===----------------------------------------------------------------------===// 837 // Verifier for AtomicWriteOp 838 //===----------------------------------------------------------------------===// 839 840 LogicalResult AtomicWriteOp::verify() { 841 if (auto mo = memory_order_val()) { 842 if (*mo == ClauseMemoryOrderKind::Acq_rel || 843 *mo == ClauseMemoryOrderKind::Acquire) { 844 return emitError( 845 "memory-order must not be acq_rel or acquire for atomic writes"); 846 } 847 } 848 if (address().getType().cast<PointerLikeType>().getElementType() != 849 value().getType()) 850 return emitError("address must dereference to value type"); 851 return verifySynchronizationHint(*this, hint_val()); 852 } 853 854 //===----------------------------------------------------------------------===// 855 // Verifier for AtomicUpdateOp 856 //===----------------------------------------------------------------------===// 857 858 LogicalResult AtomicUpdateOp::verify() { 859 if (auto mo = memory_order_val()) { 860 if (*mo == ClauseMemoryOrderKind::Acq_rel || 861 *mo == ClauseMemoryOrderKind::Acquire) { 862 return emitError( 863 "memory-order must not be acq_rel or acquire for atomic updates"); 864 } 865 } 866 867 if (x().getType().cast<PointerLikeType>().getElementType() != 868 region().getArgument(0).getType()) { 869 return emitError("the type of the operand must be a pointer type whose " 870 "element type is the same as that of the region argument"); 871 } 872 873 return verifySynchronizationHint(*this, hint_val()); 874 } 875 876 LogicalResult AtomicUpdateOp::verifyRegions() { 877 if (region().getNumArguments() != 1) 878 return emitError("the region must accept exactly one argument"); 879 880 if (region().front().getOperations().size() < 2) 881 return emitError() << "the update region must have at least two operations " 882 "(binop and terminator)"; 883 884 YieldOp yieldOp = *region().getOps<YieldOp>().begin(); 885 886 if (yieldOp.results().size() != 1) 887 return emitError("only updated value must be returned"); 888 if (yieldOp.results().front().getType() != region().getArgument(0).getType()) 889 return emitError("input and yielded value must have the same type"); 890 return success(); 891 } 892 893 //===----------------------------------------------------------------------===// 894 // Verifier for AtomicCaptureOp 895 //===----------------------------------------------------------------------===// 896 897 Operation *AtomicCaptureOp::getFirstOp() { 898 return &getRegion().front().getOperations().front(); 899 } 900 901 Operation *AtomicCaptureOp::getSecondOp() { 902 auto &ops = getRegion().front().getOperations(); 903 return ops.getNextNode(ops.front()); 904 } 905 906 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() { 907 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp())) 908 return op; 909 return dyn_cast<AtomicReadOp>(getSecondOp()); 910 } 911 912 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() { 913 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp())) 914 return op; 915 return dyn_cast<AtomicWriteOp>(getSecondOp()); 916 } 917 918 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() { 919 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp())) 920 return op; 921 return dyn_cast<AtomicUpdateOp>(getSecondOp()); 922 } 923 924 LogicalResult AtomicCaptureOp::verify() { 925 return verifySynchronizationHint(*this, hint_val()); 926 } 927 928 LogicalResult AtomicCaptureOp::verifyRegions() { 929 Block::OpListType &ops = region().front().getOperations(); 930 if (ops.size() != 3) 931 return emitError() 932 << "expected three operations in omp.atomic.capture region (one " 933 "terminator, and two atomic ops)"; 934 auto &firstOp = ops.front(); 935 auto &secondOp = *ops.getNextNode(firstOp); 936 auto firstReadStmt = dyn_cast<AtomicReadOp>(firstOp); 937 auto firstUpdateStmt = dyn_cast<AtomicUpdateOp>(firstOp); 938 auto secondReadStmt = dyn_cast<AtomicReadOp>(secondOp); 939 auto secondUpdateStmt = dyn_cast<AtomicUpdateOp>(secondOp); 940 auto secondWriteStmt = dyn_cast<AtomicWriteOp>(secondOp); 941 942 if (!((firstUpdateStmt && secondReadStmt) || 943 (firstReadStmt && secondUpdateStmt) || 944 (firstReadStmt && secondWriteStmt))) 945 return ops.front().emitError() 946 << "invalid sequence of operations in the capture region"; 947 if (firstUpdateStmt && secondReadStmt && 948 firstUpdateStmt.x() != secondReadStmt.x()) 949 return firstUpdateStmt.emitError() 950 << "updated variable in omp.atomic.update must be captured in " 951 "second operation"; 952 if (firstReadStmt && secondUpdateStmt && 953 firstReadStmt.x() != secondUpdateStmt.x()) 954 return firstReadStmt.emitError() 955 << "captured variable in omp.atomic.read must be updated in second " 956 "operation"; 957 if (firstReadStmt && secondWriteStmt && 958 firstReadStmt.x() != secondWriteStmt.address()) 959 return firstReadStmt.emitError() 960 << "captured variable in omp.atomic.read must be updated in " 961 "second operation"; 962 963 if (getFirstOp()->getAttr("hint_val") || getSecondOp()->getAttr("hint_val")) 964 return emitOpError( 965 "operations inside capture region must not have hint clause"); 966 967 if (getFirstOp()->getAttr("memory_order_val") || 968 getSecondOp()->getAttr("memory_order_val")) 969 return emitOpError( 970 "operations inside capture region must not have memory_order clause"); 971 return success(); 972 } 973 974 //===----------------------------------------------------------------------===// 975 // Verifier for CancelOp 976 //===----------------------------------------------------------------------===// 977 978 LogicalResult CancelOp::verify() { 979 ClauseCancellationConstructType cct = cancellation_construct_type_val(); 980 Operation *parentOp = (*this)->getParentOp(); 981 982 if (!parentOp) { 983 return emitOpError() << "must be used within a region supporting " 984 "cancel directive"; 985 } 986 987 if ((cct == ClauseCancellationConstructType::Parallel) && 988 !isa<ParallelOp>(parentOp)) { 989 return emitOpError() << "cancel parallel must appear " 990 << "inside a parallel region"; 991 } 992 if (cct == ClauseCancellationConstructType::Loop) { 993 if (!isa<WsLoopOp>(parentOp)) { 994 return emitOpError() << "cancel loop must appear " 995 << "inside a worksharing-loop region"; 996 } 997 if (cast<WsLoopOp>(parentOp).nowaitAttr()) { 998 return emitError() << "A worksharing construct that is canceled " 999 << "must not have a nowait clause"; 1000 } 1001 if (cast<WsLoopOp>(parentOp).ordered_valAttr()) { 1002 return emitError() << "A worksharing construct that is canceled " 1003 << "must not have an ordered clause"; 1004 } 1005 1006 } else if (cct == ClauseCancellationConstructType::Sections) { 1007 if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) { 1008 return emitOpError() << "cancel sections must appear " 1009 << "inside a sections region"; 1010 } 1011 if (isa_and_nonnull<SectionsOp>(parentOp->getParentOp()) && 1012 cast<SectionsOp>(parentOp->getParentOp()).nowaitAttr()) { 1013 return emitError() << "A sections construct that is canceled " 1014 << "must not have a nowait clause"; 1015 } 1016 } 1017 // TODO : Add more when we support taskgroup. 1018 return success(); 1019 } 1020 //===----------------------------------------------------------------------===// 1021 // Verifier for CancelOp 1022 //===----------------------------------------------------------------------===// 1023 1024 LogicalResult CancellationPointOp::verify() { 1025 ClauseCancellationConstructType cct = cancellation_construct_type_val(); 1026 Operation *parentOp = (*this)->getParentOp(); 1027 1028 if (!parentOp) { 1029 return emitOpError() << "must be used within a region supporting " 1030 "cancellation point directive"; 1031 } 1032 1033 if ((cct == ClauseCancellationConstructType::Parallel) && 1034 !(isa<ParallelOp>(parentOp))) { 1035 return emitOpError() << "cancellation point parallel must appear " 1036 << "inside a parallel region"; 1037 } 1038 if ((cct == ClauseCancellationConstructType::Loop) && 1039 !isa<WsLoopOp>(parentOp)) { 1040 return emitOpError() << "cancellation point loop must appear " 1041 << "inside a worksharing-loop region"; 1042 } 1043 if ((cct == ClauseCancellationConstructType::Sections) && 1044 !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) { 1045 return emitOpError() << "cancellation point sections must appear " 1046 << "inside a sections region"; 1047 } 1048 // TODO : Add more when we support taskgroup. 1049 return success(); 1050 } 1051 1052 #define GET_ATTRDEF_CLASSES 1053 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" 1054 1055 #define GET_OP_CLASSES 1056 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 1057