1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the OpenMP dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/IR/Attributes.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/IR/OperationSupport.h" 19 20 #include "llvm/ADT/BitVector.h" 21 #include "llvm/ADT/SmallString.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/ADT/StringRef.h" 24 #include "llvm/ADT/StringSwitch.h" 25 #include <cstddef> 26 27 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc" 28 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" 29 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc" 30 31 using namespace mlir; 32 using namespace mlir::omp; 33 34 namespace { 35 /// Model for pointer-like types that already provide a `getElementType` method. 36 template <typename T> 37 struct PointerLikeModel 38 : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> { 39 Type getElementType(Type pointer) const { 40 return pointer.cast<T>().getElementType(); 41 } 42 }; 43 } // end namespace 44 45 void OpenMPDialect::initialize() { 46 addOperations< 47 #define GET_OP_LIST 48 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 49 >(); 50 51 LLVM::LLVMPointerType::attachInterface< 52 PointerLikeModel<LLVM::LLVMPointerType>>(*getContext()); 53 MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext()); 54 } 55 56 //===----------------------------------------------------------------------===// 57 // ParallelOp 58 //===----------------------------------------------------------------------===// 59 60 void ParallelOp::build(OpBuilder &builder, OperationState &state, 61 ArrayRef<NamedAttribute> attributes) { 62 ParallelOp::build( 63 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, 64 /*default_val=*/nullptr, /*private_vars=*/ValueRange(), 65 /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(), 66 /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(), 67 /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr); 68 state.addAttributes(attributes); 69 } 70 71 //===----------------------------------------------------------------------===// 72 // Parser and printer for Operand and type list 73 //===----------------------------------------------------------------------===// 74 75 /// Parse a list of operands with types. 76 /// 77 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)` 78 /// ssa-id-and-type-list ::= ssa-id-and-type | 79 /// ssa-id-and-type `,` ssa-id-and-type-list 80 /// ssa-id-and-type ::= ssa-id `:` type 81 static ParseResult 82 parseOperandAndTypeList(OpAsmParser &parser, 83 SmallVectorImpl<OpAsmParser::OperandType> &operands, 84 SmallVectorImpl<Type> &types) { 85 return parser.parseCommaSeparatedList( 86 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 87 OpAsmParser::OperandType operand; 88 Type type; 89 if (parser.parseOperand(operand) || parser.parseColonType(type)) 90 return failure(); 91 operands.push_back(operand); 92 types.push_back(type); 93 return success(); 94 }); 95 } 96 97 /// Print an operand and type list with parentheses 98 static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) { 99 p << "("; 100 llvm::interleaveComma( 101 operands, p, [&](const Value &v) { p << v << " : " << v.getType(); }); 102 p << ") "; 103 } 104 105 /// Print data variables corresponding to a data-sharing clause `name` 106 static void printDataVars(OpAsmPrinter &p, OperandRange operands, 107 StringRef name) { 108 if (operands.size()) { 109 p << name; 110 printOperandAndTypeList(p, operands); 111 } 112 } 113 114 //===----------------------------------------------------------------------===// 115 // Parser and printer for Allocate Clause 116 //===----------------------------------------------------------------------===// 117 118 /// Parse an allocate clause with allocators and a list of operands with types. 119 /// 120 /// allocate ::= `allocate` `(` allocate-operand-list `)` 121 /// allocate-operand-list :: = allocate-operand | 122 /// allocator-operand `,` allocate-operand-list 123 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 124 /// ssa-id-and-type ::= ssa-id `:` type 125 static ParseResult parseAllocateAndAllocator( 126 OpAsmParser &parser, 127 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate, 128 SmallVectorImpl<Type> &typesAllocate, 129 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator, 130 SmallVectorImpl<Type> &typesAllocator) { 131 132 return parser.parseCommaSeparatedList( 133 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 134 OpAsmParser::OperandType operand; 135 Type type; 136 if (parser.parseOperand(operand) || parser.parseColonType(type)) 137 return failure(); 138 operandsAllocator.push_back(operand); 139 typesAllocator.push_back(type); 140 if (parser.parseArrow()) 141 return failure(); 142 if (parser.parseOperand(operand) || parser.parseColonType(type)) 143 return failure(); 144 145 operandsAllocate.push_back(operand); 146 typesAllocate.push_back(type); 147 return success(); 148 }); 149 } 150 151 /// Print allocate clause 152 static void printAllocateAndAllocator(OpAsmPrinter &p, 153 OperandRange varsAllocate, 154 OperandRange varsAllocator) { 155 if (varsAllocate.empty()) 156 return; 157 158 p << "allocate("; 159 for (unsigned i = 0; i < varsAllocate.size(); ++i) { 160 std::string separator = i == varsAllocate.size() - 1 ? ") " : ", "; 161 p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> "; 162 p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator; 163 } 164 } 165 166 static LogicalResult verifyParallelOp(ParallelOp op) { 167 if (op.allocate_vars().size() != op.allocators_vars().size()) 168 return op.emitError( 169 "expected equal sizes for allocate and allocator variables"); 170 return success(); 171 } 172 173 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) { 174 p << " "; 175 if (auto ifCond = op.if_expr_var()) 176 p << "if(" << ifCond << " : " << ifCond.getType() << ") "; 177 178 if (auto threads = op.num_threads_var()) 179 p << "num_threads(" << threads << " : " << threads.getType() << ") "; 180 181 printDataVars(p, op.private_vars(), "private"); 182 printDataVars(p, op.firstprivate_vars(), "firstprivate"); 183 printDataVars(p, op.shared_vars(), "shared"); 184 printDataVars(p, op.copyin_vars(), "copyin"); 185 printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars()); 186 187 if (auto def = op.default_val()) 188 p << "default(" << def->drop_front(3) << ") "; 189 190 if (auto bind = op.proc_bind_val()) 191 p << "proc_bind(" << bind << ") "; 192 193 p.printRegion(op.getRegion()); 194 } 195 196 //===----------------------------------------------------------------------===// 197 // Parser and printer for Linear Clause 198 //===----------------------------------------------------------------------===// 199 200 /// linear ::= `linear` `(` linear-list `)` 201 /// linear-list := linear-val | linear-val linear-list 202 /// linear-val := ssa-id-and-type `=` ssa-id-and-type 203 static ParseResult 204 parseLinearClause(OpAsmParser &parser, 205 SmallVectorImpl<OpAsmParser::OperandType> &vars, 206 SmallVectorImpl<Type> &types, 207 SmallVectorImpl<OpAsmParser::OperandType> &stepVars) { 208 if (parser.parseLParen()) 209 return failure(); 210 211 do { 212 OpAsmParser::OperandType var; 213 Type type; 214 OpAsmParser::OperandType stepVar; 215 if (parser.parseOperand(var) || parser.parseEqual() || 216 parser.parseOperand(stepVar) || parser.parseColonType(type)) 217 return failure(); 218 219 vars.push_back(var); 220 types.push_back(type); 221 stepVars.push_back(stepVar); 222 } while (succeeded(parser.parseOptionalComma())); 223 224 if (parser.parseRParen()) 225 return failure(); 226 227 return success(); 228 } 229 230 /// Print Linear Clause 231 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars, 232 OperandRange linearStepVars) { 233 size_t linearVarsSize = linearVars.size(); 234 p << "("; 235 for (unsigned i = 0; i < linearVarsSize; ++i) { 236 std::string separator = i == linearVarsSize - 1 ? ") " : ", "; 237 p << linearVars[i]; 238 if (linearStepVars.size() > i) 239 p << " = " << linearStepVars[i]; 240 p << " : " << linearVars[i].getType() << separator; 241 } 242 } 243 244 //===----------------------------------------------------------------------===// 245 // Parser and printer for Schedule Clause 246 //===----------------------------------------------------------------------===// 247 248 /// schedule ::= `schedule` `(` sched-list `)` 249 /// sched-list ::= sched-val | sched-val sched-list 250 /// sched-val ::= sched-with-chunk | sched-wo-chunk 251 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 252 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided` 253 /// sched-wo-chunk ::= `auto` | `runtime` 254 static ParseResult 255 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, 256 Optional<OpAsmParser::OperandType> &chunkSize) { 257 if (parser.parseLParen()) 258 return failure(); 259 260 StringRef keyword; 261 if (parser.parseKeyword(&keyword)) 262 return failure(); 263 264 schedule = keyword; 265 if (keyword == "static" || keyword == "dynamic" || keyword == "guided") { 266 if (succeeded(parser.parseOptionalEqual())) { 267 chunkSize = OpAsmParser::OperandType{}; 268 if (parser.parseOperand(*chunkSize)) 269 return failure(); 270 } else { 271 chunkSize = llvm::NoneType::None; 272 } 273 } else if (keyword == "auto" || keyword == "runtime") { 274 chunkSize = llvm::NoneType::None; 275 } else { 276 return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; 277 } 278 279 if (parser.parseRParen()) 280 return failure(); 281 282 return success(); 283 } 284 285 /// Print schedule clause 286 static void printScheduleClause(OpAsmPrinter &p, StringRef &sched, 287 Value scheduleChunkVar) { 288 std::string schedLower = sched.lower(); 289 p << "(" << schedLower; 290 if (scheduleChunkVar) 291 p << " = " << scheduleChunkVar; 292 p << ") "; 293 } 294 295 //===----------------------------------------------------------------------===// 296 // Parser, printer and verifier for ReductionVarList 297 //===----------------------------------------------------------------------===// 298 299 /// reduction ::= `reduction` `(` reduction-entry-list `)` 300 /// reduction-entry-list ::= reduction-entry 301 /// | reduction-entry-list `,` reduction-entry 302 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type 303 static ParseResult 304 parseReductionVarList(OpAsmParser &parser, 305 SmallVectorImpl<SymbolRefAttr> &symbols, 306 SmallVectorImpl<OpAsmParser::OperandType> &operands, 307 SmallVectorImpl<Type> &types) { 308 if (failed(parser.parseLParen())) 309 return failure(); 310 311 do { 312 if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() || 313 parser.parseOperand(operands.emplace_back()) || 314 parser.parseColonType(types.emplace_back())) 315 return failure(); 316 } while (succeeded(parser.parseOptionalComma())); 317 return parser.parseRParen(); 318 } 319 320 /// Print Reduction clause 321 static void printReductionVarList(OpAsmPrinter &p, 322 Optional<ArrayAttr> reductions, 323 OperandRange reduction_vars) { 324 for (unsigned i = 0, e = reductions->size(); i < e; ++i) { 325 if (i != 0) 326 p << ", "; 327 p << (*reductions)[i] << " -> " << reduction_vars[i] << " : " 328 << reduction_vars[i].getType(); 329 } 330 p << ") "; 331 } 332 333 /// Verifies Reduction Clause 334 static LogicalResult verifyReductionVarList(Operation *op, 335 Optional<ArrayAttr> reductions, 336 OperandRange reduction_vars) { 337 if (reduction_vars.size() != 0) { 338 if (!reductions || reductions->size() != reduction_vars.size()) 339 return op->emitOpError() 340 << "expected as many reduction symbol references " 341 "as reduction variables"; 342 } else { 343 if (reductions) 344 return op->emitOpError() << "unexpected reduction symbol references"; 345 return success(); 346 } 347 348 DenseSet<Value> accumulators; 349 for (auto args : llvm::zip(reduction_vars, *reductions)) { 350 Value accum = std::get<0>(args); 351 352 if (!accumulators.insert(accum).second) 353 return op->emitOpError() << "accumulator variable used more than once"; 354 355 Type varType = accum.getType().cast<PointerLikeType>(); 356 auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>(); 357 auto decl = 358 SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef); 359 if (!decl) 360 return op->emitOpError() << "expected symbol reference " << symbolRef 361 << " to point to a reduction declaration"; 362 363 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) 364 return op->emitOpError() 365 << "expected accumulator (" << varType 366 << ") to be the same type as reduction declaration (" 367 << decl.getAccumulatorType() << ")"; 368 } 369 370 return success(); 371 } 372 373 //===----------------------------------------------------------------------===// 374 // Parser, printer and verifier for Synchronization Hint (2.17.12) 375 //===----------------------------------------------------------------------===// 376 377 /// Parses a Synchronization Hint clause. The value of hint is an integer 378 /// which is a combination of different hints from `omp_sync_hint_t`. 379 /// 380 /// hint-clause = `hint` `(` hint-value `)` 381 static ParseResult parseSynchronizationHint(OpAsmParser &parser, 382 IntegerAttr &hintAttr) { 383 if (failed(parser.parseOptionalKeyword("hint"))) { 384 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0); 385 return success(); 386 } 387 388 if (failed(parser.parseLParen())) 389 return failure(); 390 StringRef hintKeyword; 391 int64_t hint = 0; 392 do { 393 if (failed(parser.parseKeyword(&hintKeyword))) 394 return failure(); 395 if (hintKeyword == "uncontended") 396 hint |= 1; 397 else if (hintKeyword == "contended") 398 hint |= 2; 399 else if (hintKeyword == "nonspeculative") 400 hint |= 4; 401 else if (hintKeyword == "speculative") 402 hint |= 8; 403 else 404 return parser.emitError(parser.getCurrentLocation()) 405 << hintKeyword << " is not a valid hint"; 406 } while (succeeded(parser.parseOptionalComma())); 407 if (failed(parser.parseRParen())) 408 return failure(); 409 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint); 410 return success(); 411 } 412 413 /// Prints a Synchronization Hint clause 414 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, 415 IntegerAttr hintAttr) { 416 int64_t hint = hintAttr.getInt(); 417 418 if (hint == 0) 419 return; 420 421 // Helper function to get n-th bit from the right end of `value` 422 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 423 424 bool uncontended = bitn(hint, 0); 425 bool contended = bitn(hint, 1); 426 bool nonspeculative = bitn(hint, 2); 427 bool speculative = bitn(hint, 3); 428 429 SmallVector<StringRef> hints; 430 if (uncontended) 431 hints.push_back("uncontended"); 432 if (contended) 433 hints.push_back("contended"); 434 if (nonspeculative) 435 hints.push_back("nonspeculative"); 436 if (speculative) 437 hints.push_back("speculative"); 438 439 p << "hint("; 440 llvm::interleaveComma(hints, p); 441 p << ")"; 442 } 443 444 /// Verifies a synchronization hint clause 445 static LogicalResult verifySynchronizationHint(Operation *op, int32_t hint) { 446 447 // Helper function to get n-th bit from the right end of `value` 448 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 449 450 bool uncontended = bitn(hint, 0); 451 bool contended = bitn(hint, 1); 452 bool nonspeculative = bitn(hint, 2); 453 bool speculative = bitn(hint, 3); 454 455 if (uncontended && contended) 456 return op->emitOpError() << "the hints omp_sync_hint_uncontended and " 457 "omp_sync_hint_contended cannot be combined"; 458 if (nonspeculative && speculative) 459 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and " 460 "omp_sync_hint_speculative cannot be combined."; 461 return success(); 462 } 463 464 enum ClauseType { 465 ifClause, 466 numThreadsClause, 467 privateClause, 468 firstprivateClause, 469 lastprivateClause, 470 sharedClause, 471 copyinClause, 472 allocateClause, 473 defaultClause, 474 procBindClause, 475 reductionClause, 476 nowaitClause, 477 linearClause, 478 scheduleClause, 479 collapseClause, 480 orderClause, 481 orderedClause, 482 inclusiveClause, 483 COUNT 484 }; 485 486 //===----------------------------------------------------------------------===// 487 // Parser for Clause List 488 //===----------------------------------------------------------------------===// 489 490 /// Parse a list of clauses. The clauses can appear in any order, but their 491 /// operand segment indices are in the same order that they are passed in the 492 /// `clauses` list. The operand segments are added over the prevSegments 493 494 /// clause-list ::= clause clause-list | empty 495 /// clause ::= if | num-threads | private | firstprivate | lastprivate | 496 /// shared | copyin | allocate | default | proc-bind | reduction | 497 /// nowait | linear | schedule | collapse | order | ordered | 498 /// inclusive 499 /// if ::= `if` `(` ssa-id-and-type `)` 500 /// num-threads ::= `num_threads` `(` ssa-id-and-type `)` 501 /// private ::= `private` operand-and-type-list 502 /// firstprivate ::= `firstprivate` operand-and-type-list 503 /// lastprivate ::= `lastprivate` operand-and-type-list 504 /// shared ::= `shared` operand-and-type-list 505 /// copyin ::= `copyin` operand-and-type-list 506 /// allocate ::= `allocate` `(` allocate-operand-list `)` 507 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) 508 /// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` 509 /// reduction ::= `reduction` `(` reduction-entry-list `)` 510 /// nowait ::= `nowait` 511 /// linear ::= `linear` `(` linear-list `)` 512 /// schedule ::= `schedule` `(` sched-list `)` 513 /// collapse ::= `collapse` `(` ssa-id-and-type `)` 514 /// order ::= `order` `(` `concurrent` `)` 515 /// ordered ::= `ordered` `(` ssa-id-and-type `)` 516 /// inclusive ::= `inclusive` 517 /// 518 /// Note that each clause can only appear once in the clase-list. 519 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result, 520 SmallVectorImpl<ClauseType> &clauses, 521 SmallVectorImpl<int> &segments) { 522 523 // Check done[clause] to see if it has been parsed already 524 llvm::BitVector done(ClauseType::COUNT, false); 525 526 // See pos[clause] to get position of clause in operand segments 527 SmallVector<int> pos(ClauseType::COUNT, -1); 528 529 // Stores the last parsed clause keyword 530 StringRef clauseKeyword; 531 StringRef opName = result.name.getStringRef(); 532 533 // Containers for storing operands, types and attributes for various clauses 534 std::pair<OpAsmParser::OperandType, Type> ifCond; 535 std::pair<OpAsmParser::OperandType, Type> numThreads; 536 537 SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates, 538 shareds, copyins; 539 SmallVector<Type> privateTypes, firstprivateTypes, lastprivateTypes, 540 sharedTypes, copyinTypes; 541 542 SmallVector<OpAsmParser::OperandType> allocates, allocators; 543 SmallVector<Type> allocateTypes, allocatorTypes; 544 545 SmallVector<SymbolRefAttr> reductionSymbols; 546 SmallVector<OpAsmParser::OperandType> reductionVars; 547 SmallVector<Type> reductionVarTypes; 548 549 SmallVector<OpAsmParser::OperandType> linears; 550 SmallVector<Type> linearTypes; 551 SmallVector<OpAsmParser::OperandType> linearSteps; 552 553 SmallString<8> schedule; 554 Optional<OpAsmParser::OperandType> scheduleChunkSize; 555 556 // Compute the position of clauses in operand segments 557 int currPos = 0; 558 for (ClauseType clause : clauses) { 559 560 // Skip the following clauses - they do not take any position in operand 561 // segments 562 if (clause == defaultClause || clause == procBindClause || 563 clause == nowaitClause || clause == collapseClause || 564 clause == orderClause || clause == orderedClause || 565 clause == inclusiveClause) 566 continue; 567 568 pos[clause] = currPos++; 569 570 // For the following clauses, two positions are reserved in the operand 571 // segments 572 if (clause == allocateClause || clause == linearClause) 573 currPos++; 574 } 575 576 SmallVector<int> clauseSegments(currPos); 577 578 // Helper function to check if a clause is allowed/repeated or not 579 auto checkAllowed = [&](ClauseType clause, 580 bool allowRepeat = false) -> ParseResult { 581 if (!llvm::is_contained(clauses, clause)) 582 return parser.emitError(parser.getCurrentLocation()) 583 << clauseKeyword << "is not a valid clause for the " << opName 584 << " operation"; 585 if (done[clause] && !allowRepeat) 586 return parser.emitError(parser.getCurrentLocation()) 587 << "at most one " << clauseKeyword << " clause can appear on the " 588 << opName << " operation"; 589 done[clause] = true; 590 return success(); 591 }; 592 593 while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) { 594 if (clauseKeyword == "if") { 595 if (checkAllowed(ifClause) || parser.parseLParen() || 596 parser.parseOperand(ifCond.first) || 597 parser.parseColonType(ifCond.second) || parser.parseRParen()) 598 return failure(); 599 clauseSegments[pos[ifClause]] = 1; 600 } else if (clauseKeyword == "num_threads") { 601 if (checkAllowed(numThreadsClause) || parser.parseLParen() || 602 parser.parseOperand(numThreads.first) || 603 parser.parseColonType(numThreads.second) || parser.parseRParen()) 604 return failure(); 605 clauseSegments[pos[numThreadsClause]] = 1; 606 } else if (clauseKeyword == "private") { 607 if (checkAllowed(privateClause) || 608 parseOperandAndTypeList(parser, privates, privateTypes)) 609 return failure(); 610 clauseSegments[pos[privateClause]] = privates.size(); 611 } else if (clauseKeyword == "firstprivate") { 612 if (checkAllowed(firstprivateClause) || 613 parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) 614 return failure(); 615 clauseSegments[pos[firstprivateClause]] = firstprivates.size(); 616 } else if (clauseKeyword == "lastprivate") { 617 if (checkAllowed(lastprivateClause) || 618 parseOperandAndTypeList(parser, lastprivates, lastprivateTypes)) 619 return failure(); 620 clauseSegments[pos[lastprivateClause]] = lastprivates.size(); 621 } else if (clauseKeyword == "shared") { 622 if (checkAllowed(sharedClause) || 623 parseOperandAndTypeList(parser, shareds, sharedTypes)) 624 return failure(); 625 clauseSegments[pos[sharedClause]] = shareds.size(); 626 } else if (clauseKeyword == "copyin") { 627 if (checkAllowed(copyinClause) || 628 parseOperandAndTypeList(parser, copyins, copyinTypes)) 629 return failure(); 630 clauseSegments[pos[copyinClause]] = copyins.size(); 631 } else if (clauseKeyword == "allocate") { 632 if (checkAllowed(allocateClause) || 633 parseAllocateAndAllocator(parser, allocates, allocateTypes, 634 allocators, allocatorTypes)) 635 return failure(); 636 clauseSegments[pos[allocateClause]] = allocates.size(); 637 clauseSegments[pos[allocateClause] + 1] = allocators.size(); 638 } else if (clauseKeyword == "default") { 639 StringRef defval; 640 if (checkAllowed(defaultClause) || parser.parseLParen() || 641 parser.parseKeyword(&defval) || parser.parseRParen()) 642 return failure(); 643 // The def prefix is required for the attribute as "private" is a keyword 644 // in C++. 645 auto attr = parser.getBuilder().getStringAttr("def" + defval); 646 result.addAttribute("default_val", attr); 647 } else if (clauseKeyword == "proc_bind") { 648 StringRef bind; 649 if (checkAllowed(procBindClause) || parser.parseLParen() || 650 parser.parseKeyword(&bind) || parser.parseRParen()) 651 return failure(); 652 auto attr = parser.getBuilder().getStringAttr(bind); 653 result.addAttribute("proc_bind_val", attr); 654 } else if (clauseKeyword == "reduction") { 655 if (checkAllowed(reductionClause) || 656 parseReductionVarList(parser, reductionSymbols, reductionVars, 657 reductionVarTypes)) 658 return failure(); 659 clauseSegments[pos[reductionClause]] = reductionVars.size(); 660 } else if (clauseKeyword == "nowait") { 661 if (checkAllowed(nowaitClause)) 662 return failure(); 663 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 664 result.addAttribute("nowait", attr); 665 } else if (clauseKeyword == "linear") { 666 if (checkAllowed(linearClause) || 667 parseLinearClause(parser, linears, linearTypes, linearSteps)) 668 return failure(); 669 clauseSegments[pos[linearClause]] = linears.size(); 670 clauseSegments[pos[linearClause] + 1] = linearSteps.size(); 671 } else if (clauseKeyword == "schedule") { 672 if (checkAllowed(scheduleClause) || 673 parseScheduleClause(parser, schedule, scheduleChunkSize)) 674 return failure(); 675 if (scheduleChunkSize) { 676 clauseSegments[pos[scheduleClause]] = 1; 677 } 678 } else if (clauseKeyword == "collapse") { 679 auto type = parser.getBuilder().getI64Type(); 680 mlir::IntegerAttr attr; 681 if (checkAllowed(collapseClause) || parser.parseLParen() || 682 parser.parseAttribute(attr, type) || parser.parseRParen()) 683 return failure(); 684 result.addAttribute("collapse_val", attr); 685 } else if (clauseKeyword == "ordered") { 686 mlir::IntegerAttr attr; 687 if (checkAllowed(orderedClause)) 688 return failure(); 689 if (succeeded(parser.parseOptionalLParen())) { 690 auto type = parser.getBuilder().getI64Type(); 691 if (parser.parseAttribute(attr, type) || parser.parseRParen()) 692 return failure(); 693 } else { 694 // Use 0 to represent no ordered parameter was specified 695 attr = parser.getBuilder().getI64IntegerAttr(0); 696 } 697 result.addAttribute("ordered_val", attr); 698 } else if (clauseKeyword == "order") { 699 StringRef order; 700 if (checkAllowed(orderClause) || parser.parseLParen() || 701 parser.parseKeyword(&order) || parser.parseRParen()) 702 return failure(); 703 auto attr = parser.getBuilder().getStringAttr(order); 704 result.addAttribute("order", attr); 705 } else if (clauseKeyword == "inclusive") { 706 if (checkAllowed(inclusiveClause)) 707 return failure(); 708 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 709 result.addAttribute("inclusive", attr); 710 } else { 711 return parser.emitError(parser.getNameLoc()) 712 << clauseKeyword << " is not a valid clause"; 713 } 714 } 715 716 // Add if parameter. 717 if (done[ifClause] && clauseSegments[pos[ifClause]] && 718 failed( 719 parser.resolveOperand(ifCond.first, ifCond.second, result.operands))) 720 return failure(); 721 722 // Add num_threads parameter. 723 if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] && 724 failed(parser.resolveOperand(numThreads.first, numThreads.second, 725 result.operands))) 726 return failure(); 727 728 // Add private parameters. 729 if (done[privateClause] && clauseSegments[pos[privateClause]] && 730 failed(parser.resolveOperands(privates, privateTypes, 731 privates[0].location, result.operands))) 732 return failure(); 733 734 // Add firstprivate parameters. 735 if (done[firstprivateClause] && clauseSegments[pos[firstprivateClause]] && 736 failed(parser.resolveOperands(firstprivates, firstprivateTypes, 737 firstprivates[0].location, 738 result.operands))) 739 return failure(); 740 741 // Add lastprivate parameters. 742 if (done[lastprivateClause] && clauseSegments[pos[lastprivateClause]] && 743 failed(parser.resolveOperands(lastprivates, lastprivateTypes, 744 lastprivates[0].location, result.operands))) 745 return failure(); 746 747 // Add shared parameters. 748 if (done[sharedClause] && clauseSegments[pos[sharedClause]] && 749 failed(parser.resolveOperands(shareds, sharedTypes, shareds[0].location, 750 result.operands))) 751 return failure(); 752 753 // Add copyin parameters. 754 if (done[copyinClause] && clauseSegments[pos[copyinClause]] && 755 failed(parser.resolveOperands(copyins, copyinTypes, copyins[0].location, 756 result.operands))) 757 return failure(); 758 759 // Add allocate parameters. 760 if (done[allocateClause] && clauseSegments[pos[allocateClause]] && 761 failed(parser.resolveOperands(allocates, allocateTypes, 762 allocates[0].location, result.operands))) 763 return failure(); 764 765 // Add allocator parameters. 766 if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] && 767 failed(parser.resolveOperands(allocators, allocatorTypes, 768 allocators[0].location, result.operands))) 769 return failure(); 770 771 // Add reduction parameters and symbols 772 if (done[reductionClause] && clauseSegments[pos[reductionClause]]) { 773 if (failed(parser.resolveOperands(reductionVars, reductionVarTypes, 774 parser.getNameLoc(), result.operands))) 775 return failure(); 776 777 SmallVector<Attribute> reductions(reductionSymbols.begin(), 778 reductionSymbols.end()); 779 result.addAttribute("reductions", 780 parser.getBuilder().getArrayAttr(reductions)); 781 } 782 783 // Add linear parameters 784 if (done[linearClause] && clauseSegments[pos[linearClause]]) { 785 auto linearStepType = parser.getBuilder().getI32Type(); 786 SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType); 787 if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location, 788 result.operands)) || 789 failed(parser.resolveOperands(linearSteps, linearStepTypes, 790 linearSteps[0].location, 791 result.operands))) 792 return failure(); 793 } 794 795 // Add schedule parameters 796 if (done[scheduleClause] && !schedule.empty()) { 797 schedule[0] = llvm::toUpper(schedule[0]); 798 auto attr = parser.getBuilder().getStringAttr(schedule); 799 result.addAttribute("schedule_val", attr); 800 if (scheduleChunkSize) { 801 auto chunkSizeType = parser.getBuilder().getI32Type(); 802 parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands); 803 } 804 } 805 806 segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end()); 807 808 return success(); 809 } 810 811 /// Parses a parallel operation. 812 /// 813 /// operation ::= `omp.parallel` clause-list 814 /// clause-list ::= clause | clause clause-list 815 /// clause ::= if | num-threads | private | firstprivate | shared | copyin | 816 /// allocate | default | proc-bind 817 /// 818 static ParseResult parseParallelOp(OpAsmParser &parser, 819 OperationState &result) { 820 SmallVector<ClauseType> clauses = { 821 ifClause, numThreadsClause, privateClause, 822 firstprivateClause, sharedClause, copyinClause, 823 allocateClause, defaultClause, procBindClause}; 824 825 SmallVector<int> segments; 826 827 if (failed(parseClauses(parser, result, clauses, segments))) 828 return failure(); 829 830 result.addAttribute("operand_segment_sizes", 831 parser.getBuilder().getI32VectorAttr(segments)); 832 833 Region *body = result.addRegion(); 834 SmallVector<OpAsmParser::OperandType> regionArgs; 835 SmallVector<Type> regionArgTypes; 836 if (parser.parseRegion(*body, regionArgs, regionArgTypes)) 837 return failure(); 838 return success(); 839 } 840 841 /// Parses an OpenMP Workshare Loop operation 842 /// 843 /// wsloop ::= `omp.wsloop` loop-control clause-list 844 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds 845 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps 846 /// steps := `step` `(`ssa-id-list`)` 847 /// clause-list ::= clause clause-list | empty 848 /// clause ::= private | firstprivate | lastprivate | linear | schedule | 849 // collapse | nowait | ordered | order | inclusive | reduction 850 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) { 851 852 // Parse an opening `(` followed by induction variables followed by `)` 853 SmallVector<OpAsmParser::OperandType> ivs; 854 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, 855 OpAsmParser::Delimiter::Paren)) 856 return failure(); 857 858 int numIVs = static_cast<int>(ivs.size()); 859 Type loopVarType; 860 if (parser.parseColonType(loopVarType)) 861 return failure(); 862 863 // Parse loop bounds. 864 SmallVector<OpAsmParser::OperandType> lower; 865 if (parser.parseEqual() || 866 parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) || 867 parser.resolveOperands(lower, loopVarType, result.operands)) 868 return failure(); 869 870 SmallVector<OpAsmParser::OperandType> upper; 871 if (parser.parseKeyword("to") || 872 parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) || 873 parser.resolveOperands(upper, loopVarType, result.operands)) 874 return failure(); 875 876 // Parse step values. 877 SmallVector<OpAsmParser::OperandType> steps; 878 if (parser.parseKeyword("step") || 879 parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) || 880 parser.resolveOperands(steps, loopVarType, result.operands)) 881 return failure(); 882 883 SmallVector<ClauseType> clauses = { 884 privateClause, firstprivateClause, lastprivateClause, linearClause, 885 reductionClause, collapseClause, orderClause, orderedClause, 886 nowaitClause, scheduleClause}; 887 SmallVector<int> segments{numIVs, numIVs, numIVs}; 888 if (failed(parseClauses(parser, result, clauses, segments))) 889 return failure(); 890 891 result.addAttribute("operand_segment_sizes", 892 parser.getBuilder().getI32VectorAttr(segments)); 893 894 // Now parse the body. 895 Region *body = result.addRegion(); 896 SmallVector<Type> ivTypes(numIVs, loopVarType); 897 SmallVector<OpAsmParser::OperandType> blockArgs(ivs); 898 if (parser.parseRegion(*body, blockArgs, ivTypes)) 899 return failure(); 900 return success(); 901 } 902 903 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) { 904 auto args = op.getRegion().front().getArguments(); 905 p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound() 906 << ") to (" << op.upperBound() << ") step (" << op.step() << ") "; 907 908 printDataVars(p, op.private_vars(), "private"); 909 printDataVars(p, op.firstprivate_vars(), "firstprivate"); 910 printDataVars(p, op.lastprivate_vars(), "lastprivate"); 911 912 if (op.linear_vars().size()) { 913 p << "linear"; 914 printLinearClause(p, op.linear_vars(), op.linear_step_vars()); 915 } 916 917 if (auto sched = op.schedule_val()) { 918 p << "schedule"; 919 printScheduleClause(p, sched.getValue(), op.schedule_chunk_var()); 920 } 921 922 if (auto collapse = op.collapse_val()) 923 p << "collapse(" << collapse << ") "; 924 925 if (op.nowait()) 926 p << "nowait "; 927 928 if (auto ordered = op.ordered_val()) 929 p << "ordered(" << ordered << ") "; 930 931 if (!op.reduction_vars().empty()) { 932 p << "reduction("; 933 printReductionVarList(p, op.reductions(), op.reduction_vars()); 934 } 935 936 if (op.inclusive()) { 937 p << "inclusive "; 938 } 939 940 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 941 } 942 943 //===----------------------------------------------------------------------===// 944 // ReductionOp 945 //===----------------------------------------------------------------------===// 946 947 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, 948 Region ®ion) { 949 if (parser.parseOptionalKeyword("atomic")) 950 return success(); 951 return parser.parseRegion(region); 952 } 953 954 static void printAtomicReductionRegion(OpAsmPrinter &printer, 955 ReductionDeclareOp op, Region ®ion) { 956 if (region.empty()) 957 return; 958 printer << "atomic "; 959 printer.printRegion(region); 960 } 961 962 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) { 963 if (op.initializerRegion().empty()) 964 return op.emitOpError() << "expects non-empty initializer region"; 965 Block &initializerEntryBlock = op.initializerRegion().front(); 966 if (initializerEntryBlock.getNumArguments() != 1 || 967 initializerEntryBlock.getArgument(0).getType() != op.type()) { 968 return op.emitOpError() << "expects initializer region with one argument " 969 "of the reduction type"; 970 } 971 972 for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) { 973 if (yieldOp.results().size() != 1 || 974 yieldOp.results().getTypes()[0] != op.type()) 975 return op.emitOpError() << "expects initializer region to yield a value " 976 "of the reduction type"; 977 } 978 979 if (op.reductionRegion().empty()) 980 return op.emitOpError() << "expects non-empty reduction region"; 981 Block &reductionEntryBlock = op.reductionRegion().front(); 982 if (reductionEntryBlock.getNumArguments() != 2 || 983 reductionEntryBlock.getArgumentTypes()[0] != 984 reductionEntryBlock.getArgumentTypes()[1] || 985 reductionEntryBlock.getArgumentTypes()[0] != op.type()) 986 return op.emitOpError() << "expects reduction region with two arguments of " 987 "the reduction type"; 988 for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) { 989 if (yieldOp.results().size() != 1 || 990 yieldOp.results().getTypes()[0] != op.type()) 991 return op.emitOpError() << "expects reduction region to yield a value " 992 "of the reduction type"; 993 } 994 995 if (op.atomicReductionRegion().empty()) 996 return success(); 997 998 Block &atomicReductionEntryBlock = op.atomicReductionRegion().front(); 999 if (atomicReductionEntryBlock.getNumArguments() != 2 || 1000 atomicReductionEntryBlock.getArgumentTypes()[0] != 1001 atomicReductionEntryBlock.getArgumentTypes()[1]) 1002 return op.emitOpError() << "expects atomic reduction region with two " 1003 "arguments of the same type"; 1004 auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0] 1005 .dyn_cast<PointerLikeType>(); 1006 if (!ptrType || ptrType.getElementType() != op.type()) 1007 return op.emitOpError() << "expects atomic reduction region arguments to " 1008 "be accumulators containing the reduction type"; 1009 return success(); 1010 } 1011 1012 static LogicalResult verifyReductionOp(ReductionOp op) { 1013 // TODO: generalize this to an op interface when there is more than one op 1014 // that supports reductions. 1015 auto container = op->getParentOfType<WsLoopOp>(); 1016 for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) 1017 if (container.reduction_vars()[i] == op.accumulator()) 1018 return success(); 1019 1020 return op.emitOpError() << "the accumulator is not used by the parent"; 1021 } 1022 1023 //===----------------------------------------------------------------------===// 1024 // WsLoopOp 1025 //===----------------------------------------------------------------------===// 1026 1027 void WsLoopOp::build(OpBuilder &builder, OperationState &state, 1028 ValueRange lowerBound, ValueRange upperBound, 1029 ValueRange step, ArrayRef<NamedAttribute> attributes) { 1030 build(builder, state, TypeRange(), lowerBound, upperBound, step, 1031 /*private_vars=*/ValueRange(), 1032 /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(), 1033 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), 1034 /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr, 1035 /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr, 1036 /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr, 1037 /*inclusive=*/nullptr, /*buildBody=*/false); 1038 state.addAttributes(attributes); 1039 } 1040 1041 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes, 1042 ValueRange operands, ArrayRef<NamedAttribute> attributes) { 1043 state.addOperands(operands); 1044 state.addAttributes(attributes); 1045 (void)state.addRegion(); 1046 assert(resultTypes.empty() && "mismatched number of return types"); 1047 state.addTypes(resultTypes); 1048 } 1049 1050 void WsLoopOp::build(OpBuilder &builder, OperationState &result, 1051 TypeRange typeRange, ValueRange lowerBounds, 1052 ValueRange upperBounds, ValueRange steps, 1053 ValueRange privateVars, ValueRange firstprivateVars, 1054 ValueRange lastprivateVars, ValueRange linearVars, 1055 ValueRange linearStepVars, ValueRange reductionVars, 1056 StringAttr scheduleVal, Value scheduleChunkVar, 1057 IntegerAttr collapseVal, UnitAttr nowait, 1058 IntegerAttr orderedVal, StringAttr orderVal, 1059 UnitAttr inclusive, bool buildBody) { 1060 result.addOperands(lowerBounds); 1061 result.addOperands(upperBounds); 1062 result.addOperands(steps); 1063 result.addOperands(privateVars); 1064 result.addOperands(firstprivateVars); 1065 result.addOperands(linearVars); 1066 result.addOperands(linearStepVars); 1067 if (scheduleChunkVar) 1068 result.addOperands(scheduleChunkVar); 1069 1070 if (scheduleVal) 1071 result.addAttribute("schedule_val", scheduleVal); 1072 if (collapseVal) 1073 result.addAttribute("collapse_val", collapseVal); 1074 if (nowait) 1075 result.addAttribute("nowait", nowait); 1076 if (orderedVal) 1077 result.addAttribute("ordered_val", orderedVal); 1078 if (orderVal) 1079 result.addAttribute("order", orderVal); 1080 if (inclusive) 1081 result.addAttribute("inclusive", inclusive); 1082 result.addAttribute( 1083 WsLoopOp::getOperandSegmentSizeAttr(), 1084 builder.getI32VectorAttr( 1085 {static_cast<int32_t>(lowerBounds.size()), 1086 static_cast<int32_t>(upperBounds.size()), 1087 static_cast<int32_t>(steps.size()), 1088 static_cast<int32_t>(privateVars.size()), 1089 static_cast<int32_t>(firstprivateVars.size()), 1090 static_cast<int32_t>(lastprivateVars.size()), 1091 static_cast<int32_t>(linearVars.size()), 1092 static_cast<int32_t>(linearStepVars.size()), 1093 static_cast<int32_t>(reductionVars.size()), 1094 static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)})); 1095 1096 Region *bodyRegion = result.addRegion(); 1097 if (buildBody) { 1098 OpBuilder::InsertionGuard guard(builder); 1099 unsigned numIVs = steps.size(); 1100 SmallVector<Type, 8> argTypes(numIVs, steps.getType().front()); 1101 builder.createBlock(bodyRegion, {}, argTypes); 1102 } 1103 } 1104 1105 static LogicalResult verifyWsLoopOp(WsLoopOp op) { 1106 return verifyReductionVarList(op, op.reductions(), op.reduction_vars()); 1107 } 1108 1109 //===----------------------------------------------------------------------===// 1110 // Verifier for critical construct (2.17.1) 1111 //===----------------------------------------------------------------------===// 1112 1113 static LogicalResult verifyCriticalDeclareOp(CriticalDeclareOp op) { 1114 return verifySynchronizationHint(op, op.hint()); 1115 } 1116 1117 static LogicalResult verifyCriticalOp(CriticalOp op) { 1118 1119 if (op.nameAttr()) { 1120 auto symbolRef = op.nameAttr().cast<SymbolRefAttr>(); 1121 auto decl = 1122 SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef); 1123 if (!decl) { 1124 return op.emitOpError() << "expected symbol reference " << symbolRef 1125 << " to point to a critical declaration"; 1126 } 1127 } 1128 1129 return success(); 1130 } 1131 1132 //===----------------------------------------------------------------------===// 1133 // Verifier for ordered construct 1134 //===----------------------------------------------------------------------===// 1135 1136 static LogicalResult verifyOrderedOp(OrderedOp op) { 1137 auto container = op->getParentOfType<WsLoopOp>(); 1138 if (!container || !container.ordered_valAttr() || 1139 container.ordered_valAttr().getInt() == 0) 1140 return op.emitOpError() << "ordered depend directive must be closely " 1141 << "nested inside a worksharing-loop with ordered " 1142 << "clause with parameter present"; 1143 1144 if (container.ordered_valAttr().getInt() != 1145 (int64_t)op.num_loops_val().getValue()) 1146 return op.emitOpError() << "number of variables in depend clause does not " 1147 << "match number of iteration variables in the " 1148 << "doacross loop"; 1149 1150 return success(); 1151 } 1152 1153 static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) { 1154 // TODO: The code generation for ordered simd directive is not supported yet. 1155 if (op.simd()) 1156 return failure(); 1157 1158 if (auto container = op->getParentOfType<WsLoopOp>()) { 1159 if (!container.ordered_valAttr() || 1160 container.ordered_valAttr().getInt() != 0) 1161 return op.emitOpError() << "ordered region must be closely nested inside " 1162 << "a worksharing-loop region with an ordered " 1163 << "clause without parameter present"; 1164 } 1165 1166 return success(); 1167 } 1168 1169 #define GET_OP_CLASSES 1170 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 1171