1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the OpenMP dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/DialectImplementation.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/IR/OperationSupport.h" 19 20 #include "llvm/ADT/BitVector.h" 21 #include "llvm/ADT/SmallString.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/ADT/StringRef.h" 24 #include "llvm/ADT/StringSwitch.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 #include <cstddef> 27 28 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc" 29 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" 30 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc" 31 32 using namespace mlir; 33 using namespace mlir::omp; 34 35 namespace { 36 /// Model for pointer-like types that already provide a `getElementType` method. 37 template <typename T> 38 struct PointerLikeModel 39 : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> { 40 Type getElementType(Type pointer) const { 41 return pointer.cast<T>().getElementType(); 42 } 43 }; 44 } // namespace 45 46 void OpenMPDialect::initialize() { 47 addOperations< 48 #define GET_OP_LIST 49 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 50 >(); 51 addAttributes< 52 #define GET_ATTRDEF_LIST 53 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" 54 >(); 55 56 LLVM::LLVMPointerType::attachInterface< 57 PointerLikeModel<LLVM::LLVMPointerType>>(*getContext()); 58 MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext()); 59 } 60 61 //===----------------------------------------------------------------------===// 62 // ParallelOp 63 //===----------------------------------------------------------------------===// 64 65 void ParallelOp::build(OpBuilder &builder, OperationState &state, 66 ArrayRef<NamedAttribute> attributes) { 67 ParallelOp::build( 68 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, 69 /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(), 70 /*proc_bind_val=*/nullptr); 71 state.addAttributes(attributes); 72 } 73 74 //===----------------------------------------------------------------------===// 75 // Parser and printer for Allocate Clause 76 //===----------------------------------------------------------------------===// 77 78 /// Parse an allocate clause with allocators and a list of operands with types. 79 /// 80 /// allocate-operand-list :: = allocate-operand | 81 /// allocator-operand `,` allocate-operand-list 82 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 83 /// ssa-id-and-type ::= ssa-id `:` type 84 static ParseResult parseAllocateAndAllocator( 85 OpAsmParser &parser, 86 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operandsAllocate, 87 SmallVectorImpl<Type> &typesAllocate, 88 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operandsAllocator, 89 SmallVectorImpl<Type> &typesAllocator) { 90 91 return parser.parseCommaSeparatedList([&]() -> ParseResult { 92 OpAsmParser::UnresolvedOperand operand; 93 Type type; 94 if (parser.parseOperand(operand) || parser.parseColonType(type)) 95 return failure(); 96 operandsAllocator.push_back(operand); 97 typesAllocator.push_back(type); 98 if (parser.parseArrow()) 99 return failure(); 100 if (parser.parseOperand(operand) || parser.parseColonType(type)) 101 return failure(); 102 103 operandsAllocate.push_back(operand); 104 typesAllocate.push_back(type); 105 return success(); 106 }); 107 } 108 109 /// Print allocate clause 110 static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, 111 OperandRange varsAllocate, 112 TypeRange typesAllocate, 113 OperandRange varsAllocator, 114 TypeRange typesAllocator) { 115 for (unsigned i = 0; i < varsAllocate.size(); ++i) { 116 std::string separator = i == varsAllocate.size() - 1 ? "" : ", "; 117 p << varsAllocator[i] << " : " << typesAllocator[i] << " -> "; 118 p << varsAllocate[i] << " : " << typesAllocate[i] << separator; 119 } 120 } 121 122 //===----------------------------------------------------------------------===// 123 // Parser and printer for a clause attribute (StringEnumAttr) 124 //===----------------------------------------------------------------------===// 125 126 template <typename ClauseAttr> 127 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) { 128 using ClauseT = decltype(std::declval<ClauseAttr>().getValue()); 129 StringRef enumStr; 130 SMLoc loc = parser.getCurrentLocation(); 131 if (parser.parseKeyword(&enumStr)) 132 return failure(); 133 if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) { 134 attr = ClauseAttr::get(parser.getContext(), *enumValue); 135 return success(); 136 } 137 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'"; 138 } 139 140 template <typename ClauseAttr> 141 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) { 142 p << stringifyEnum(attr.getValue()); 143 } 144 145 LogicalResult ParallelOp::verify() { 146 if (allocate_vars().size() != allocators_vars().size()) 147 return emitError( 148 "expected equal sizes for allocate and allocator variables"); 149 return success(); 150 } 151 152 //===----------------------------------------------------------------------===// 153 // Parser and printer for Linear Clause 154 //===----------------------------------------------------------------------===// 155 156 /// linear ::= `linear` `(` linear-list `)` 157 /// linear-list := linear-val | linear-val linear-list 158 /// linear-val := ssa-id-and-type `=` ssa-id-and-type 159 static ParseResult 160 parseLinearClause(OpAsmParser &parser, 161 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars, 162 SmallVectorImpl<Type> &types, 163 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &stepVars) { 164 do { 165 OpAsmParser::UnresolvedOperand var; 166 Type type; 167 OpAsmParser::UnresolvedOperand stepVar; 168 if (parser.parseOperand(var) || parser.parseEqual() || 169 parser.parseOperand(stepVar) || parser.parseColonType(type)) 170 return failure(); 171 172 vars.push_back(var); 173 types.push_back(type); 174 stepVars.push_back(stepVar); 175 } while (succeeded(parser.parseOptionalComma())); 176 return success(); 177 } 178 179 /// Print Linear Clause 180 static void printLinearClause(OpAsmPrinter &p, Operation *op, 181 ValueRange linearVars, TypeRange linearVarTypes, 182 ValueRange linearStepVars) { 183 size_t linearVarsSize = linearVars.size(); 184 for (unsigned i = 0; i < linearVarsSize; ++i) { 185 std::string separator = i == linearVarsSize - 1 ? "" : ", "; 186 p << linearVars[i]; 187 if (linearStepVars.size() > i) 188 p << " = " << linearStepVars[i]; 189 p << " : " << linearVars[i].getType() << separator; 190 } 191 } 192 193 //===----------------------------------------------------------------------===// 194 // Parser, printer and verifier for Schedule Clause 195 //===----------------------------------------------------------------------===// 196 197 static ParseResult 198 verifyScheduleModifiers(OpAsmParser &parser, 199 SmallVectorImpl<SmallString<12>> &modifiers) { 200 if (modifiers.size() > 2) 201 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)"; 202 for (const auto &mod : modifiers) { 203 // Translate the string. If it has no value, then it was not a valid 204 // modifier! 205 auto symbol = symbolizeScheduleModifier(mod); 206 if (!symbol.hasValue()) 207 return parser.emitError(parser.getNameLoc()) 208 << " unknown modifier type: " << mod; 209 } 210 211 // If we have one modifier that is "simd", then stick a "none" modiifer in 212 // index 0. 213 if (modifiers.size() == 1) { 214 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) { 215 modifiers.push_back(modifiers[0]); 216 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none); 217 } 218 } else if (modifiers.size() == 2) { 219 // If there are two modifier: 220 // First modifier should not be simd, second one should be simd 221 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd || 222 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd) 223 return parser.emitError(parser.getNameLoc()) 224 << " incorrect modifier order"; 225 } 226 return success(); 227 } 228 229 /// schedule ::= `schedule` `(` sched-list `)` 230 /// sched-list ::= sched-val | sched-val sched-list | 231 /// sched-val `,` sched-modifier 232 /// sched-val ::= sched-with-chunk | sched-wo-chunk 233 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 234 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided` 235 /// sched-wo-chunk ::= `auto` | `runtime` 236 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val 237 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none` 238 static ParseResult parseScheduleClause( 239 OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, 240 ScheduleModifierAttr &schedule_modifier, UnitAttr &simdModifier, 241 Optional<OpAsmParser::UnresolvedOperand> &chunkSize, Type &chunkType) { 242 StringRef keyword; 243 if (parser.parseKeyword(&keyword)) 244 return failure(); 245 llvm::Optional<mlir::omp::ClauseScheduleKind> schedule = 246 symbolizeClauseScheduleKind(keyword); 247 if (!schedule) 248 return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; 249 250 scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule); 251 switch (*schedule) { 252 case ClauseScheduleKind::Static: 253 case ClauseScheduleKind::Dynamic: 254 case ClauseScheduleKind::Guided: 255 if (succeeded(parser.parseOptionalEqual())) { 256 chunkSize = OpAsmParser::UnresolvedOperand{}; 257 if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType)) 258 return failure(); 259 } else { 260 chunkSize = llvm::NoneType::None; 261 } 262 break; 263 case ClauseScheduleKind::Auto: 264 case ClauseScheduleKind::Runtime: 265 chunkSize = llvm::NoneType::None; 266 } 267 268 // If there is a comma, we have one or more modifiers.. 269 SmallVector<SmallString<12>> modifiers; 270 while (succeeded(parser.parseOptionalComma())) { 271 StringRef mod; 272 if (parser.parseKeyword(&mod)) 273 return failure(); 274 modifiers.push_back(mod); 275 } 276 277 if (verifyScheduleModifiers(parser, modifiers)) 278 return failure(); 279 280 if (!modifiers.empty()) { 281 SMLoc loc = parser.getCurrentLocation(); 282 if (Optional<ScheduleModifier> mod = 283 symbolizeScheduleModifier(modifiers[0])) { 284 schedule_modifier = ScheduleModifierAttr::get(parser.getContext(), *mod); 285 } else { 286 return parser.emitError(loc, "invalid schedule modifier"); 287 } 288 // Only SIMD attribute is allowed here! 289 if (modifiers.size() > 1) { 290 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd); 291 simdModifier = UnitAttr::get(parser.getBuilder().getContext()); 292 } 293 } 294 295 return success(); 296 } 297 298 /// Print schedule clause 299 static void printScheduleClause(OpAsmPrinter &p, Operation *op, 300 ClauseScheduleKindAttr schedAttr, 301 ScheduleModifierAttr modifier, UnitAttr simd, 302 Value scheduleChunkVar, 303 Type scheduleChunkType) { 304 p << stringifyClauseScheduleKind(schedAttr.getValue()); 305 if (scheduleChunkVar) 306 p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType(); 307 if (modifier) 308 p << ", " << stringifyScheduleModifier(modifier.getValue()); 309 if (simd) 310 p << ", simd"; 311 } 312 313 //===----------------------------------------------------------------------===// 314 // Parser, printer and verifier for ReductionVarList 315 //===----------------------------------------------------------------------===// 316 317 /// reduction-entry-list ::= reduction-entry 318 /// | reduction-entry-list `,` reduction-entry 319 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type 320 static ParseResult 321 parseReductionVarList(OpAsmParser &parser, 322 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands, 323 SmallVectorImpl<Type> &types, 324 ArrayAttr &redcuctionSymbols) { 325 SmallVector<SymbolRefAttr> reductionVec; 326 do { 327 if (parser.parseAttribute(reductionVec.emplace_back()) || 328 parser.parseArrow() || parser.parseOperand(operands.emplace_back()) || 329 parser.parseColonType(types.emplace_back())) 330 return failure(); 331 } while (succeeded(parser.parseOptionalComma())); 332 SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end()); 333 redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions); 334 return success(); 335 } 336 337 /// Print Reduction clause 338 static void printReductionVarList(OpAsmPrinter &p, Operation *op, 339 OperandRange reductionVars, 340 TypeRange reductionTypes, 341 Optional<ArrayAttr> reductions) { 342 for (unsigned i = 0, e = reductions->size(); i < e; ++i) { 343 if (i != 0) 344 p << ", "; 345 p << (*reductions)[i] << " -> " << reductionVars[i] << " : " 346 << reductionVars[i].getType(); 347 } 348 } 349 350 /// Verifies Reduction Clause 351 static LogicalResult verifyReductionVarList(Operation *op, 352 Optional<ArrayAttr> reductions, 353 OperandRange reductionVars) { 354 if (!reductionVars.empty()) { 355 if (!reductions || reductions->size() != reductionVars.size()) 356 return op->emitOpError() 357 << "expected as many reduction symbol references " 358 "as reduction variables"; 359 } else { 360 if (reductions) 361 return op->emitOpError() << "unexpected reduction symbol references"; 362 return success(); 363 } 364 365 // TODO: The followings should be done in 366 // SymbolUserOpInterface::verifySymbolUses. 367 DenseSet<Value> accumulators; 368 for (auto args : llvm::zip(reductionVars, *reductions)) { 369 Value accum = std::get<0>(args); 370 371 if (!accumulators.insert(accum).second) 372 return op->emitOpError() << "accumulator variable used more than once"; 373 374 Type varType = accum.getType().cast<PointerLikeType>(); 375 auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>(); 376 auto decl = 377 SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef); 378 if (!decl) 379 return op->emitOpError() << "expected symbol reference " << symbolRef 380 << " to point to a reduction declaration"; 381 382 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) 383 return op->emitOpError() 384 << "expected accumulator (" << varType 385 << ") to be the same type as reduction declaration (" 386 << decl.getAccumulatorType() << ")"; 387 } 388 389 return success(); 390 } 391 392 //===----------------------------------------------------------------------===// 393 // Parser, printer and verifier for Synchronization Hint (2.17.12) 394 //===----------------------------------------------------------------------===// 395 396 /// Parses a Synchronization Hint clause. The value of hint is an integer 397 /// which is a combination of different hints from `omp_sync_hint_t`. 398 /// 399 /// hint-clause = `hint` `(` hint-value `)` 400 static ParseResult parseSynchronizationHint(OpAsmParser &parser, 401 IntegerAttr &hintAttr) { 402 StringRef hintKeyword; 403 int64_t hint = 0; 404 do { 405 if (failed(parser.parseKeyword(&hintKeyword))) 406 return failure(); 407 if (hintKeyword == "uncontended") 408 hint |= 1; 409 else if (hintKeyword == "contended") 410 hint |= 2; 411 else if (hintKeyword == "nonspeculative") 412 hint |= 4; 413 else if (hintKeyword == "speculative") 414 hint |= 8; 415 else 416 return parser.emitError(parser.getCurrentLocation()) 417 << hintKeyword << " is not a valid hint"; 418 } while (succeeded(parser.parseOptionalComma())); 419 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint); 420 return success(); 421 } 422 423 /// Prints a Synchronization Hint clause 424 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, 425 IntegerAttr hintAttr) { 426 int64_t hint = hintAttr.getInt(); 427 428 if (hint == 0) 429 return; 430 431 // Helper function to get n-th bit from the right end of `value` 432 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 433 434 bool uncontended = bitn(hint, 0); 435 bool contended = bitn(hint, 1); 436 bool nonspeculative = bitn(hint, 2); 437 bool speculative = bitn(hint, 3); 438 439 SmallVector<StringRef> hints; 440 if (uncontended) 441 hints.push_back("uncontended"); 442 if (contended) 443 hints.push_back("contended"); 444 if (nonspeculative) 445 hints.push_back("nonspeculative"); 446 if (speculative) 447 hints.push_back("speculative"); 448 449 llvm::interleaveComma(hints, p); 450 } 451 452 /// Verifies a synchronization hint clause 453 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) { 454 455 // Helper function to get n-th bit from the right end of `value` 456 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 457 458 bool uncontended = bitn(hint, 0); 459 bool contended = bitn(hint, 1); 460 bool nonspeculative = bitn(hint, 2); 461 bool speculative = bitn(hint, 3); 462 463 if (uncontended && contended) 464 return op->emitOpError() << "the hints omp_sync_hint_uncontended and " 465 "omp_sync_hint_contended cannot be combined"; 466 if (nonspeculative && speculative) 467 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and " 468 "omp_sync_hint_speculative cannot be combined."; 469 return success(); 470 } 471 472 //===----------------------------------------------------------------------===// 473 // Verifier for SectionsOp 474 //===----------------------------------------------------------------------===// 475 476 LogicalResult SectionsOp::verify() { 477 if (allocate_vars().size() != allocators_vars().size()) 478 return emitError( 479 "expected equal sizes for allocate and allocator variables"); 480 481 return verifyReductionVarList(*this, reductions(), reduction_vars()); 482 } 483 484 LogicalResult SectionsOp::verifyRegions() { 485 for (auto &inst : *region().begin()) { 486 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) { 487 return emitOpError() 488 << "expected omp.section op or terminator op inside region"; 489 } 490 } 491 492 return success(); 493 } 494 495 LogicalResult SingleOp::verify() { 496 // Check for allocate clause restrictions 497 if (allocate_vars().size() != allocators_vars().size()) 498 return emitError( 499 "expected equal sizes for allocate and allocator variables"); 500 501 return success(); 502 } 503 504 //===----------------------------------------------------------------------===// 505 // WsLoopOp 506 //===----------------------------------------------------------------------===// 507 508 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds 509 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps 510 /// steps := `step` `(`ssa-id-list`)` 511 ParseResult 512 parseWsLoopControl(OpAsmParser &parser, Region ®ion, 513 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &lowerBound, 514 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &upperBound, 515 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &steps, 516 SmallVectorImpl<Type> &loopVarTypes, UnitAttr &inclusive) { 517 // Parse an opening `(` followed by induction variables followed by `)` 518 SmallVector<OpAsmParser::UnresolvedOperand> ivs; 519 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, 520 OpAsmParser::Delimiter::Paren)) 521 return failure(); 522 523 size_t numIVs = ivs.size(); 524 Type loopVarType; 525 if (parser.parseColonType(loopVarType)) 526 return failure(); 527 528 // Parse loop bounds. 529 if (parser.parseEqual() || 530 parser.parseOperandList(lowerBound, numIVs, 531 OpAsmParser::Delimiter::Paren)) 532 return failure(); 533 if (parser.parseKeyword("to") || 534 parser.parseOperandList(upperBound, numIVs, 535 OpAsmParser::Delimiter::Paren)) 536 return failure(); 537 538 if (succeeded(parser.parseOptionalKeyword("inclusive"))) { 539 inclusive = UnitAttr::get(parser.getBuilder().getContext()); 540 } 541 542 // Parse step values. 543 if (parser.parseKeyword("step") || 544 parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren)) 545 return failure(); 546 547 // Now parse the body. 548 loopVarTypes = SmallVector<Type>(numIVs, loopVarType); 549 SmallVector<OpAsmParser::UnresolvedOperand> blockArgs(ivs); 550 if (parser.parseRegion(region, blockArgs, loopVarTypes)) 551 return failure(); 552 return success(); 553 } 554 555 void printWsLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, 556 ValueRange lowerBound, ValueRange upperBound, 557 ValueRange steps, TypeRange loopVarTypes, 558 UnitAttr inclusive) { 559 auto args = region.front().getArguments(); 560 p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound 561 << ") to (" << upperBound << ") "; 562 if (inclusive) 563 p << "inclusive "; 564 p << "step (" << steps << ") "; 565 p.printRegion(region, /*printEntryBlockArgs=*/false); 566 } 567 568 //===----------------------------------------------------------------------===// 569 // SimdLoopOp 570 //===----------------------------------------------------------------------===// 571 /// Parses an OpenMP Simd construct [2.9.3.1] 572 /// 573 /// simdloop ::= `omp.simdloop` loop-control clause-list 574 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds 575 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps 576 /// steps := `step` `(`ssa-id-list`)` 577 /// clause-list ::= clause clause-list | empty 578 /// clause ::= TODO 579 ParseResult SimdLoopOp::parse(OpAsmParser &parser, OperationState &result) { 580 // Parse an opening `(` followed by induction variables followed by `)` 581 SmallVector<OpAsmParser::UnresolvedOperand> ivs; 582 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, 583 OpAsmParser::Delimiter::Paren)) 584 return failure(); 585 int numIVs = static_cast<int>(ivs.size()); 586 Type loopVarType; 587 if (parser.parseColonType(loopVarType)) 588 return failure(); 589 // Parse loop bounds. 590 SmallVector<OpAsmParser::UnresolvedOperand> lower; 591 if (parser.parseEqual() || 592 parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) || 593 parser.resolveOperands(lower, loopVarType, result.operands)) 594 return failure(); 595 SmallVector<OpAsmParser::UnresolvedOperand> upper; 596 if (parser.parseKeyword("to") || 597 parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) || 598 parser.resolveOperands(upper, loopVarType, result.operands)) 599 return failure(); 600 601 // Parse step values. 602 SmallVector<OpAsmParser::UnresolvedOperand> steps; 603 if (parser.parseKeyword("step") || 604 parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) || 605 parser.resolveOperands(steps, loopVarType, result.operands)) 606 return failure(); 607 608 SmallVector<int> segments{numIVs, numIVs, numIVs}; 609 // TODO: Add parseClauses() when we support clauses 610 result.addAttribute("operand_segment_sizes", 611 parser.getBuilder().getI32VectorAttr(segments)); 612 613 // Now parse the body. 614 Region *body = result.addRegion(); 615 SmallVector<Type> ivTypes(numIVs, loopVarType); 616 SmallVector<OpAsmParser::UnresolvedOperand> blockArgs(ivs); 617 if (parser.parseRegion(*body, blockArgs, ivTypes)) 618 return failure(); 619 return success(); 620 } 621 622 void SimdLoopOp::print(OpAsmPrinter &p) { 623 auto args = getRegion().front().getArguments(); 624 p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound() 625 << ") to (" << upperBound() << ") "; 626 p << "step (" << step() << ") "; 627 628 p.printRegion(region(), /*printEntryBlockArgs=*/false); 629 } 630 631 //===----------------------------------------------------------------------===// 632 // Verifier for Simd construct [2.9.3.1] 633 //===----------------------------------------------------------------------===// 634 635 LogicalResult SimdLoopOp::verify() { 636 if (this->lowerBound().empty()) { 637 return emitOpError() << "empty lowerbound for simd loop operation"; 638 } 639 return success(); 640 } 641 642 //===----------------------------------------------------------------------===// 643 // ReductionOp 644 //===----------------------------------------------------------------------===// 645 646 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, 647 Region ®ion) { 648 if (parser.parseOptionalKeyword("atomic")) 649 return success(); 650 return parser.parseRegion(region); 651 } 652 653 static void printAtomicReductionRegion(OpAsmPrinter &printer, 654 ReductionDeclareOp op, Region ®ion) { 655 if (region.empty()) 656 return; 657 printer << "atomic "; 658 printer.printRegion(region); 659 } 660 661 LogicalResult ReductionDeclareOp::verifyRegions() { 662 if (initializerRegion().empty()) 663 return emitOpError() << "expects non-empty initializer region"; 664 Block &initializerEntryBlock = initializerRegion().front(); 665 if (initializerEntryBlock.getNumArguments() != 1 || 666 initializerEntryBlock.getArgument(0).getType() != type()) { 667 return emitOpError() << "expects initializer region with one argument " 668 "of the reduction type"; 669 } 670 671 for (YieldOp yieldOp : initializerRegion().getOps<YieldOp>()) { 672 if (yieldOp.results().size() != 1 || 673 yieldOp.results().getTypes()[0] != type()) 674 return emitOpError() << "expects initializer region to yield a value " 675 "of the reduction type"; 676 } 677 678 if (reductionRegion().empty()) 679 return emitOpError() << "expects non-empty reduction region"; 680 Block &reductionEntryBlock = reductionRegion().front(); 681 if (reductionEntryBlock.getNumArguments() != 2 || 682 reductionEntryBlock.getArgumentTypes()[0] != 683 reductionEntryBlock.getArgumentTypes()[1] || 684 reductionEntryBlock.getArgumentTypes()[0] != type()) 685 return emitOpError() << "expects reduction region with two arguments of " 686 "the reduction type"; 687 for (YieldOp yieldOp : reductionRegion().getOps<YieldOp>()) { 688 if (yieldOp.results().size() != 1 || 689 yieldOp.results().getTypes()[0] != type()) 690 return emitOpError() << "expects reduction region to yield a value " 691 "of the reduction type"; 692 } 693 694 if (atomicReductionRegion().empty()) 695 return success(); 696 697 Block &atomicReductionEntryBlock = atomicReductionRegion().front(); 698 if (atomicReductionEntryBlock.getNumArguments() != 2 || 699 atomicReductionEntryBlock.getArgumentTypes()[0] != 700 atomicReductionEntryBlock.getArgumentTypes()[1]) 701 return emitOpError() << "expects atomic reduction region with two " 702 "arguments of the same type"; 703 auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0] 704 .dyn_cast<PointerLikeType>(); 705 if (!ptrType || ptrType.getElementType() != type()) 706 return emitOpError() << "expects atomic reduction region arguments to " 707 "be accumulators containing the reduction type"; 708 return success(); 709 } 710 711 LogicalResult ReductionOp::verify() { 712 // TODO: generalize this to an op interface when there is more than one op 713 // that supports reductions. 714 auto container = (*this)->getParentOfType<WsLoopOp>(); 715 for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) 716 if (container.reduction_vars()[i] == accumulator()) 717 return success(); 718 719 return emitOpError() << "the accumulator is not used by the parent"; 720 } 721 722 //===----------------------------------------------------------------------===// 723 // WsLoopOp 724 //===----------------------------------------------------------------------===// 725 726 void WsLoopOp::build(OpBuilder &builder, OperationState &state, 727 ValueRange lowerBound, ValueRange upperBound, 728 ValueRange step, ArrayRef<NamedAttribute> attributes) { 729 build(builder, state, lowerBound, upperBound, step, 730 /*linear_vars=*/ValueRange(), 731 /*linear_step_vars=*/ValueRange(), /*reduction_vars=*/ValueRange(), 732 /*reductions=*/nullptr, /*schedule_val=*/nullptr, 733 /*schedule_chunk_var=*/nullptr, /*schedule_modifier=*/nullptr, 734 /*simd_modifier=*/false, /*collapse_val=*/nullptr, /*nowait=*/false, 735 /*ordered_val=*/nullptr, /*order_val=*/nullptr, /*inclusive=*/false); 736 state.addAttributes(attributes); 737 } 738 739 LogicalResult WsLoopOp::verify() { 740 return verifyReductionVarList(*this, reductions(), reduction_vars()); 741 } 742 743 //===----------------------------------------------------------------------===// 744 // Verifier for critical construct (2.17.1) 745 //===----------------------------------------------------------------------===// 746 747 LogicalResult CriticalDeclareOp::verify() { 748 return verifySynchronizationHint(*this, hint_val()); 749 } 750 751 LogicalResult 752 CriticalOp::verifySymbolUses(SymbolTableCollection &symbol_table) { 753 if (nameAttr()) { 754 SymbolRefAttr symbolRef = nameAttr(); 755 auto decl = symbol_table.lookupNearestSymbolFrom<CriticalDeclareOp>( 756 *this, symbolRef); 757 if (!decl) { 758 return emitOpError() << "expected symbol reference " << symbolRef 759 << " to point to a critical declaration"; 760 } 761 } 762 763 return success(); 764 } 765 766 //===----------------------------------------------------------------------===// 767 // Verifier for ordered construct 768 //===----------------------------------------------------------------------===// 769 770 LogicalResult OrderedOp::verify() { 771 auto container = (*this)->getParentOfType<WsLoopOp>(); 772 if (!container || !container.ordered_valAttr() || 773 container.ordered_valAttr().getInt() == 0) 774 return emitOpError() << "ordered depend directive must be closely " 775 << "nested inside a worksharing-loop with ordered " 776 << "clause with parameter present"; 777 778 if (container.ordered_valAttr().getInt() != 779 (int64_t)num_loops_val().getValue()) 780 return emitOpError() << "number of variables in depend clause does not " 781 << "match number of iteration variables in the " 782 << "doacross loop"; 783 784 return success(); 785 } 786 787 LogicalResult OrderedRegionOp::verify() { 788 // TODO: The code generation for ordered simd directive is not supported yet. 789 if (simd()) 790 return failure(); 791 792 if (auto container = (*this)->getParentOfType<WsLoopOp>()) { 793 if (!container.ordered_valAttr() || 794 container.ordered_valAttr().getInt() != 0) 795 return emitOpError() << "ordered region must be closely nested inside " 796 << "a worksharing-loop region with an ordered " 797 << "clause without parameter present"; 798 } 799 800 return success(); 801 } 802 803 //===----------------------------------------------------------------------===// 804 // Verifier for AtomicReadOp 805 //===----------------------------------------------------------------------===// 806 807 LogicalResult AtomicReadOp::verify() { 808 if (auto mo = memory_order_val()) { 809 if (*mo == ClauseMemoryOrderKind::Acq_rel || 810 *mo == ClauseMemoryOrderKind::Release) { 811 return emitError( 812 "memory-order must not be acq_rel or release for atomic reads"); 813 } 814 } 815 if (x() == v()) 816 return emitError( 817 "read and write must not be to the same location for atomic reads"); 818 return verifySynchronizationHint(*this, hint_val()); 819 } 820 821 //===----------------------------------------------------------------------===// 822 // Verifier for AtomicWriteOp 823 //===----------------------------------------------------------------------===// 824 825 LogicalResult AtomicWriteOp::verify() { 826 if (auto mo = memory_order_val()) { 827 if (*mo == ClauseMemoryOrderKind::Acq_rel || 828 *mo == ClauseMemoryOrderKind::Acquire) { 829 return emitError( 830 "memory-order must not be acq_rel or acquire for atomic writes"); 831 } 832 } 833 return verifySynchronizationHint(*this, hint_val()); 834 } 835 836 //===----------------------------------------------------------------------===// 837 // Verifier for AtomicUpdateOp 838 //===----------------------------------------------------------------------===// 839 840 LogicalResult AtomicUpdateOp::verify() { 841 if (auto mo = memory_order_val()) { 842 if (*mo == ClauseMemoryOrderKind::Acq_rel || 843 *mo == ClauseMemoryOrderKind::Acquire) { 844 return emitError( 845 "memory-order must not be acq_rel or acquire for atomic updates"); 846 } 847 } 848 849 if (x().getType().cast<PointerLikeType>().getElementType() != 850 region().getArgument(0).getType()) { 851 return emitError("the type of the operand must be a pointer type whose " 852 "element type is the same as that of the region argument"); 853 } 854 855 return success(); 856 } 857 858 LogicalResult AtomicUpdateOp::verifyRegions() { 859 if (region().getNumArguments() != 1) 860 return emitError("the region must accept exactly one argument"); 861 862 if (region().front().getOperations().size() < 2) 863 return emitError() << "the update region must have at least two operations " 864 "(binop and terminator)"; 865 866 YieldOp yieldOp = *region().getOps<YieldOp>().begin(); 867 868 if (yieldOp.results().size() != 1) 869 return emitError("only updated value must be returned"); 870 if (yieldOp.results().front().getType() != region().getArgument(0).getType()) 871 return emitError("input and yielded value must have the same type"); 872 return success(); 873 } 874 875 //===----------------------------------------------------------------------===// 876 // Verifier for AtomicCaptureOp 877 //===----------------------------------------------------------------------===// 878 879 Operation *AtomicCaptureOp::getFirstOp() { 880 return &getRegion().front().getOperations().front(); 881 } 882 883 Operation *AtomicCaptureOp::getSecondOp() { 884 auto &ops = getRegion().front().getOperations(); 885 return ops.getNextNode(ops.front()); 886 } 887 888 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() { 889 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp())) 890 return op; 891 return dyn_cast<AtomicReadOp>(getSecondOp()); 892 } 893 894 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() { 895 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp())) 896 return op; 897 return dyn_cast<AtomicWriteOp>(getSecondOp()); 898 } 899 900 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() { 901 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp())) 902 return op; 903 return dyn_cast<AtomicUpdateOp>(getSecondOp()); 904 } 905 906 LogicalResult AtomicCaptureOp::verifyRegions() { 907 Block::OpListType &ops = region().front().getOperations(); 908 if (ops.size() != 3) 909 return emitError() 910 << "expected three operations in omp.atomic.capture region (one " 911 "terminator, and two atomic ops)"; 912 auto &firstOp = ops.front(); 913 auto &secondOp = *ops.getNextNode(firstOp); 914 auto firstReadStmt = dyn_cast<AtomicReadOp>(firstOp); 915 auto firstUpdateStmt = dyn_cast<AtomicUpdateOp>(firstOp); 916 auto secondReadStmt = dyn_cast<AtomicReadOp>(secondOp); 917 auto secondUpdateStmt = dyn_cast<AtomicUpdateOp>(secondOp); 918 auto secondWriteStmt = dyn_cast<AtomicWriteOp>(secondOp); 919 920 if (!((firstUpdateStmt && secondReadStmt) || 921 (firstReadStmt && secondUpdateStmt) || 922 (firstReadStmt && secondWriteStmt))) 923 return ops.front().emitError() 924 << "invalid sequence of operations in the capture region"; 925 if (firstUpdateStmt && secondReadStmt && 926 firstUpdateStmt.x() != secondReadStmt.x()) 927 return firstUpdateStmt.emitError() 928 << "updated variable in omp.atomic.update must be captured in " 929 "second operation"; 930 if (firstReadStmt && secondUpdateStmt && 931 firstReadStmt.x() != secondUpdateStmt.x()) 932 return firstReadStmt.emitError() 933 << "captured variable in omp.atomic.read must be updated in second " 934 "operation"; 935 if (firstReadStmt && secondWriteStmt && 936 firstReadStmt.x() != secondWriteStmt.address()) 937 return firstReadStmt.emitError() 938 << "captured variable in omp.atomic.read must be updated in " 939 "second operation"; 940 return success(); 941 } 942 943 #define GET_ATTRDEF_CLASSES 944 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" 945 946 #define GET_OP_CLASSES 947 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 948