1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the OpenMP dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/IR/Attributes.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/IR/OperationSupport.h" 19 20 #include "llvm/ADT/SmallString.h" 21 #include "llvm/ADT/StringExtras.h" 22 #include "llvm/ADT/StringRef.h" 23 #include "llvm/ADT/StringSwitch.h" 24 #include <cstddef> 25 26 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc" 27 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" 28 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc" 29 30 using namespace mlir; 31 using namespace mlir::omp; 32 33 namespace { 34 /// Model for pointer-like types that already provide a `getElementType` method. 35 template <typename T> 36 struct PointerLikeModel 37 : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> { 38 Type getElementType(Type pointer) const { 39 return pointer.cast<T>().getElementType(); 40 } 41 }; 42 } // end namespace 43 44 void OpenMPDialect::initialize() { 45 addOperations< 46 #define GET_OP_LIST 47 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 48 >(); 49 50 LLVM::LLVMPointerType::attachInterface< 51 PointerLikeModel<LLVM::LLVMPointerType>>(*getContext()); 52 MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext()); 53 } 54 55 //===----------------------------------------------------------------------===// 56 // ParallelOp 57 //===----------------------------------------------------------------------===// 58 59 void ParallelOp::build(OpBuilder &builder, OperationState &state, 60 ArrayRef<NamedAttribute> attributes) { 61 ParallelOp::build( 62 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, 63 /*default_val=*/nullptr, /*private_vars=*/ValueRange(), 64 /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(), 65 /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(), 66 /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr); 67 state.addAttributes(attributes); 68 } 69 70 /// Parse a list of operands with types. 71 /// 72 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)` 73 /// ssa-id-and-type-list ::= ssa-id-and-type | 74 /// ssa-id-and-type `,` ssa-id-and-type-list 75 /// ssa-id-and-type ::= ssa-id `:` type 76 static ParseResult 77 parseOperandAndTypeList(OpAsmParser &parser, 78 SmallVectorImpl<OpAsmParser::OperandType> &operands, 79 SmallVectorImpl<Type> &types) { 80 return parser.parseCommaSeparatedList( 81 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 82 OpAsmParser::OperandType operand; 83 Type type; 84 if (parser.parseOperand(operand) || parser.parseColonType(type)) 85 return failure(); 86 operands.push_back(operand); 87 types.push_back(type); 88 return success(); 89 }); 90 } 91 92 /// Parse an allocate clause with allocators and a list of operands with types. 93 /// 94 /// operand-and-type-list ::= `(` allocate-operand-list `)` 95 /// allocate-operand-list :: = allocate-operand | 96 /// allocator-operand `,` allocate-operand-list 97 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 98 /// ssa-id-and-type ::= ssa-id `:` type 99 static ParseResult parseAllocateAndAllocator( 100 OpAsmParser &parser, 101 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate, 102 SmallVectorImpl<Type> &typesAllocate, 103 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator, 104 SmallVectorImpl<Type> &typesAllocator) { 105 106 return parser.parseCommaSeparatedList( 107 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { 108 OpAsmParser::OperandType operand; 109 Type type; 110 if (parser.parseOperand(operand) || parser.parseColonType(type)) 111 return failure(); 112 operandsAllocator.push_back(operand); 113 typesAllocator.push_back(type); 114 if (parser.parseArrow()) 115 return failure(); 116 if (parser.parseOperand(operand) || parser.parseColonType(type)) 117 return failure(); 118 119 operandsAllocate.push_back(operand); 120 typesAllocate.push_back(type); 121 return success(); 122 }); 123 } 124 125 static LogicalResult verifyParallelOp(ParallelOp op) { 126 if (op.allocate_vars().size() != op.allocators_vars().size()) 127 return op.emitError( 128 "expected equal sizes for allocate and allocator variables"); 129 return success(); 130 } 131 132 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) { 133 if (auto ifCond = op.if_expr_var()) 134 p << " if(" << ifCond << " : " << ifCond.getType() << ")"; 135 136 if (auto threads = op.num_threads_var()) 137 p << " num_threads(" << threads << " : " << threads.getType() << ")"; 138 139 // Print private, firstprivate, shared and copyin parameters 140 auto printDataVars = [&p](StringRef name, OperandRange vars) { 141 if (vars.size()) { 142 p << " " << name << "("; 143 for (unsigned i = 0; i < vars.size(); ++i) { 144 std::string separator = i == vars.size() - 1 ? ")" : ", "; 145 p << vars[i] << " : " << vars[i].getType() << separator; 146 } 147 } 148 }; 149 150 // Print allocator and allocate parameters 151 auto printAllocateAndAllocator = [&p](OperandRange varsAllocate, 152 OperandRange varsAllocator) { 153 if (varsAllocate.empty()) 154 return; 155 156 p << " allocate("; 157 for (unsigned i = 0; i < varsAllocate.size(); ++i) { 158 std::string separator = i == varsAllocate.size() - 1 ? ")" : ", "; 159 p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> "; 160 p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator; 161 } 162 }; 163 164 printDataVars("private", op.private_vars()); 165 printDataVars("firstprivate", op.firstprivate_vars()); 166 printDataVars("shared", op.shared_vars()); 167 printDataVars("copyin", op.copyin_vars()); 168 printAllocateAndAllocator(op.allocate_vars(), op.allocators_vars()); 169 170 if (auto def = op.default_val()) 171 p << " default(" << def->drop_front(3) << ")"; 172 173 if (auto bind = op.proc_bind_val()) 174 p << " proc_bind(" << bind << ")"; 175 176 p.printRegion(op.getRegion()); 177 } 178 179 /// Emit an error if the same clause is present more than once on an operation. 180 static ParseResult allowedOnce(OpAsmParser &parser, StringRef clause, 181 StringRef operation) { 182 return parser.emitError(parser.getNameLoc()) 183 << " at most one " << clause << " clause can appear on the " 184 << operation << " operation"; 185 } 186 187 /// Parses a parallel operation. 188 /// 189 /// operation ::= `omp.parallel` clause-list 190 /// clause-list ::= clause | clause clause-list 191 /// clause ::= if | numThreads | private | firstprivate | shared | copyin | 192 /// default | procBind 193 /// if ::= `if` `(` ssa-id `)` 194 /// numThreads ::= `num_threads` `(` ssa-id-and-type `)` 195 /// private ::= `private` operand-and-type-list 196 /// firstprivate ::= `firstprivate` operand-and-type-list 197 /// shared ::= `shared` operand-and-type-list 198 /// copyin ::= `copyin` operand-and-type-list 199 /// allocate ::= `allocate` operand-and-type `->` operand-and-type-list 200 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) 201 /// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` 202 /// 203 /// Note that each clause can only appear once in the clase-list. 204 static ParseResult parseParallelOp(OpAsmParser &parser, 205 OperationState &result) { 206 std::pair<OpAsmParser::OperandType, Type> ifCond; 207 std::pair<OpAsmParser::OperandType, Type> numThreads; 208 SmallVector<OpAsmParser::OperandType, 4> privates; 209 SmallVector<Type, 4> privateTypes; 210 SmallVector<OpAsmParser::OperandType, 4> firstprivates; 211 SmallVector<Type, 4> firstprivateTypes; 212 SmallVector<OpAsmParser::OperandType, 4> shareds; 213 SmallVector<Type, 4> sharedTypes; 214 SmallVector<OpAsmParser::OperandType, 4> copyins; 215 SmallVector<Type, 4> copyinTypes; 216 SmallVector<OpAsmParser::OperandType, 4> allocates; 217 SmallVector<Type, 4> allocateTypes; 218 SmallVector<OpAsmParser::OperandType, 4> allocators; 219 SmallVector<Type, 4> allocatorTypes; 220 std::array<int, 8> segments{0, 0, 0, 0, 0, 0, 0, 0}; 221 StringRef keyword; 222 bool defaultVal = false; 223 bool procBind = false; 224 225 const int ifClausePos = 0; 226 const int numThreadsClausePos = 1; 227 const int privateClausePos = 2; 228 const int firstprivateClausePos = 3; 229 const int sharedClausePos = 4; 230 const int copyinClausePos = 5; 231 const int allocateClausePos = 6; 232 const int allocatorPos = 7; 233 const StringRef opName = result.name.getStringRef(); 234 235 while (succeeded(parser.parseOptionalKeyword(&keyword))) { 236 if (keyword == "if") { 237 // Fail if there was already another if condition. 238 if (segments[ifClausePos]) 239 return allowedOnce(parser, "if", opName); 240 if (parser.parseLParen() || parser.parseOperand(ifCond.first) || 241 parser.parseColonType(ifCond.second) || parser.parseRParen()) 242 return failure(); 243 segments[ifClausePos] = 1; 244 } else if (keyword == "num_threads") { 245 // Fail if there was already another num_threads clause. 246 if (segments[numThreadsClausePos]) 247 return allowedOnce(parser, "num_threads", opName); 248 if (parser.parseLParen() || parser.parseOperand(numThreads.first) || 249 parser.parseColonType(numThreads.second) || parser.parseRParen()) 250 return failure(); 251 segments[numThreadsClausePos] = 1; 252 } else if (keyword == "private") { 253 // Fail if there was already another private clause. 254 if (segments[privateClausePos]) 255 return allowedOnce(parser, "private", opName); 256 if (parseOperandAndTypeList(parser, privates, privateTypes)) 257 return failure(); 258 segments[privateClausePos] = privates.size(); 259 } else if (keyword == "firstprivate") { 260 // Fail if there was already another firstprivate clause. 261 if (segments[firstprivateClausePos]) 262 return allowedOnce(parser, "firstprivate", opName); 263 if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) 264 return failure(); 265 segments[firstprivateClausePos] = firstprivates.size(); 266 } else if (keyword == "shared") { 267 // Fail if there was already another shared clause. 268 if (segments[sharedClausePos]) 269 return allowedOnce(parser, "shared", opName); 270 if (parseOperandAndTypeList(parser, shareds, sharedTypes)) 271 return failure(); 272 segments[sharedClausePos] = shareds.size(); 273 } else if (keyword == "copyin") { 274 // Fail if there was already another copyin clause. 275 if (segments[copyinClausePos]) 276 return allowedOnce(parser, "copyin", opName); 277 if (parseOperandAndTypeList(parser, copyins, copyinTypes)) 278 return failure(); 279 segments[copyinClausePos] = copyins.size(); 280 } else if (keyword == "allocate") { 281 // Fail if there was already another allocate clause. 282 if (segments[allocateClausePos]) 283 return allowedOnce(parser, "allocate", opName); 284 if (parseAllocateAndAllocator(parser, allocates, allocateTypes, 285 allocators, allocatorTypes)) 286 return failure(); 287 segments[allocateClausePos] = allocates.size(); 288 segments[allocatorPos] = allocators.size(); 289 } else if (keyword == "default") { 290 // Fail if there was already another default clause. 291 if (defaultVal) 292 return allowedOnce(parser, "default", opName); 293 defaultVal = true; 294 StringRef defval; 295 if (parser.parseLParen() || parser.parseKeyword(&defval) || 296 parser.parseRParen()) 297 return failure(); 298 // The def prefix is required for the attribute as "private" is a keyword 299 // in C++. 300 auto attr = parser.getBuilder().getStringAttr("def" + defval); 301 result.addAttribute("default_val", attr); 302 } else if (keyword == "proc_bind") { 303 // Fail if there was already another proc_bind clause. 304 if (procBind) 305 return allowedOnce(parser, "proc_bind", opName); 306 procBind = true; 307 StringRef bind; 308 if (parser.parseLParen() || parser.parseKeyword(&bind) || 309 parser.parseRParen()) 310 return failure(); 311 auto attr = parser.getBuilder().getStringAttr(bind); 312 result.addAttribute("proc_bind_val", attr); 313 } else { 314 return parser.emitError(parser.getNameLoc()) 315 << keyword << " is not a valid clause for the " << opName 316 << " operation"; 317 } 318 } 319 320 // Add if parameter. 321 if (segments[ifClausePos] && 322 parser.resolveOperand(ifCond.first, ifCond.second, result.operands)) 323 return failure(); 324 325 // Add num_threads parameter. 326 if (segments[numThreadsClausePos] && 327 parser.resolveOperand(numThreads.first, numThreads.second, 328 result.operands)) 329 return failure(); 330 331 // Add private parameters. 332 if (segments[privateClausePos] && 333 parser.resolveOperands(privates, privateTypes, privates[0].location, 334 result.operands)) 335 return failure(); 336 337 // Add firstprivate parameters. 338 if (segments[firstprivateClausePos] && 339 parser.resolveOperands(firstprivates, firstprivateTypes, 340 firstprivates[0].location, result.operands)) 341 return failure(); 342 343 // Add shared parameters. 344 if (segments[sharedClausePos] && 345 parser.resolveOperands(shareds, sharedTypes, shareds[0].location, 346 result.operands)) 347 return failure(); 348 349 // Add copyin parameters. 350 if (segments[copyinClausePos] && 351 parser.resolveOperands(copyins, copyinTypes, copyins[0].location, 352 result.operands)) 353 return failure(); 354 355 // Add allocate parameters. 356 if (segments[allocateClausePos] && 357 parser.resolveOperands(allocates, allocateTypes, allocates[0].location, 358 result.operands)) 359 return failure(); 360 361 // Add allocator parameters. 362 if (segments[allocatorPos] && 363 parser.resolveOperands(allocators, allocatorTypes, allocators[0].location, 364 result.operands)) 365 return failure(); 366 367 result.addAttribute("operand_segment_sizes", 368 parser.getBuilder().getI32VectorAttr(segments)); 369 370 Region *body = result.addRegion(); 371 SmallVector<OpAsmParser::OperandType, 4> regionArgs; 372 SmallVector<Type, 4> regionArgTypes; 373 if (parser.parseRegion(*body, regionArgs, regionArgTypes)) 374 return failure(); 375 return success(); 376 } 377 378 /// linear ::= `linear` `(` linear-list `)` 379 /// linear-list := linear-val | linear-val linear-list 380 /// linear-val := ssa-id-and-type `=` ssa-id-and-type 381 static ParseResult 382 parseLinearClause(OpAsmParser &parser, 383 SmallVectorImpl<OpAsmParser::OperandType> &vars, 384 SmallVectorImpl<Type> &types, 385 SmallVectorImpl<OpAsmParser::OperandType> &stepVars) { 386 if (parser.parseLParen()) 387 return failure(); 388 389 do { 390 OpAsmParser::OperandType var; 391 Type type; 392 OpAsmParser::OperandType stepVar; 393 if (parser.parseOperand(var) || parser.parseEqual() || 394 parser.parseOperand(stepVar) || parser.parseColonType(type)) 395 return failure(); 396 397 vars.push_back(var); 398 types.push_back(type); 399 stepVars.push_back(stepVar); 400 } while (succeeded(parser.parseOptionalComma())); 401 402 if (parser.parseRParen()) 403 return failure(); 404 405 return success(); 406 } 407 408 /// schedule ::= `schedule` `(` sched-list `)` 409 /// sched-list ::= sched-val | sched-val sched-list 410 /// sched-val ::= sched-with-chunk | sched-wo-chunk 411 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 412 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided` 413 /// sched-wo-chunk ::= `auto` | `runtime` 414 static ParseResult 415 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, 416 Optional<OpAsmParser::OperandType> &chunkSize) { 417 if (parser.parseLParen()) 418 return failure(); 419 420 StringRef keyword; 421 if (parser.parseKeyword(&keyword)) 422 return failure(); 423 424 schedule = keyword; 425 if (keyword == "static" || keyword == "dynamic" || keyword == "guided") { 426 if (succeeded(parser.parseOptionalEqual())) { 427 chunkSize = OpAsmParser::OperandType{}; 428 if (parser.parseOperand(*chunkSize)) 429 return failure(); 430 } else { 431 chunkSize = llvm::NoneType::None; 432 } 433 } else if (keyword == "auto" || keyword == "runtime") { 434 chunkSize = llvm::NoneType::None; 435 } else { 436 return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; 437 } 438 439 if (parser.parseRParen()) 440 return failure(); 441 442 return success(); 443 } 444 445 /// reduction-init ::= `reduction` `(` reduction-entry-list `)` 446 /// reduction-entry-list ::= reduction-entry 447 /// | reduction-entry-list `,` reduction-entry 448 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type 449 static ParseResult 450 parseReductionVarList(OpAsmParser &parser, 451 SmallVectorImpl<SymbolRefAttr> &symbols, 452 SmallVectorImpl<OpAsmParser::OperandType> &operands, 453 SmallVectorImpl<Type> &types) { 454 if (failed(parser.parseLParen())) 455 return failure(); 456 457 do { 458 if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() || 459 parser.parseOperand(operands.emplace_back()) || 460 parser.parseColonType(types.emplace_back())) 461 return failure(); 462 } while (succeeded(parser.parseOptionalComma())); 463 return parser.parseRParen(); 464 } 465 466 /// Parses an OpenMP Workshare Loop operation 467 /// 468 /// operation ::= `omp.wsloop` loop-control clause-list 469 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds 470 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps 471 /// steps := `step` `(`ssa-id-list`)` 472 /// clause-list ::= clause | empty | clause-list 473 /// clause ::= private | firstprivate | lastprivate | linear | schedule | 474 // collapse | nowait | ordered | order | inclusive 475 /// private ::= `private` `(` ssa-id-and-type-list `)` 476 /// firstprivate ::= `firstprivate` `(` ssa-id-and-type-list `)` 477 /// lastprivate ::= `lastprivate` `(` ssa-id-and-type-list `)` 478 /// linear ::= `linear` `(` linear-list `)` 479 /// schedule ::= `schedule` `(` sched-list `)` 480 /// collapse ::= `collapse` `(` ssa-id-and-type `)` 481 /// nowait ::= `nowait` 482 /// ordered ::= `ordered` `(` ssa-id-and-type `)` 483 /// order ::= `order` `(` `concurrent` `)` 484 /// inclusive ::= `inclusive` 485 /// 486 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) { 487 Type loopVarType; 488 int numIVs; 489 490 // Parse an opening `(` followed by induction variables followed by `)` 491 SmallVector<OpAsmParser::OperandType> ivs; 492 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, 493 OpAsmParser::Delimiter::Paren)) 494 return failure(); 495 496 numIVs = static_cast<int>(ivs.size()); 497 498 if (parser.parseColonType(loopVarType)) 499 return failure(); 500 501 // Parse loop bounds. 502 SmallVector<OpAsmParser::OperandType> lower; 503 if (parser.parseEqual() || 504 parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) || 505 parser.resolveOperands(lower, loopVarType, result.operands)) 506 return failure(); 507 508 SmallVector<OpAsmParser::OperandType> upper; 509 if (parser.parseKeyword("to") || 510 parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) || 511 parser.resolveOperands(upper, loopVarType, result.operands)) 512 return failure(); 513 514 // Parse step values. 515 SmallVector<OpAsmParser::OperandType> steps; 516 if (parser.parseKeyword("step") || 517 parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) || 518 parser.resolveOperands(steps, loopVarType, result.operands)) 519 return failure(); 520 521 SmallVector<OpAsmParser::OperandType> privates; 522 SmallVector<Type> privateTypes; 523 SmallVector<OpAsmParser::OperandType> firstprivates; 524 SmallVector<Type> firstprivateTypes; 525 SmallVector<OpAsmParser::OperandType> lastprivates; 526 SmallVector<Type> lastprivateTypes; 527 SmallVector<OpAsmParser::OperandType> linears; 528 SmallVector<Type> linearTypes; 529 SmallVector<OpAsmParser::OperandType> linearSteps; 530 SmallVector<SymbolRefAttr> reductionSymbols; 531 SmallVector<OpAsmParser::OperandType> reductionVars; 532 SmallVector<Type> reductionVarTypes; 533 SmallString<8> schedule; 534 Optional<OpAsmParser::OperandType> scheduleChunkSize; 535 536 const StringRef opName = result.name.getStringRef(); 537 StringRef keyword; 538 539 enum SegmentPos { 540 lbPos = 0, 541 ubPos, 542 stepPos, 543 privateClausePos, 544 firstprivateClausePos, 545 lastprivateClausePos, 546 linearClausePos, 547 linearStepPos, 548 reductionVarPos, 549 scheduleClausePos, 550 }; 551 std::array<int, 10> segments{numIVs, numIVs, numIVs, 0, 0, 0, 0, 0, 0, 0}; 552 553 while (succeeded(parser.parseOptionalKeyword(&keyword))) { 554 if (keyword == "private") { 555 if (segments[privateClausePos]) 556 return allowedOnce(parser, "private", opName); 557 if (parseOperandAndTypeList(parser, privates, privateTypes)) 558 return failure(); 559 segments[privateClausePos] = privates.size(); 560 } else if (keyword == "firstprivate") { 561 // fail if there was already another firstprivate clause 562 if (segments[firstprivateClausePos]) 563 return allowedOnce(parser, "firstprivate", opName); 564 if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) 565 return failure(); 566 segments[firstprivateClausePos] = firstprivates.size(); 567 } else if (keyword == "lastprivate") { 568 // fail if there was already another shared clause 569 if (segments[lastprivateClausePos]) 570 return allowedOnce(parser, "lastprivate", opName); 571 if (parseOperandAndTypeList(parser, lastprivates, lastprivateTypes)) 572 return failure(); 573 segments[lastprivateClausePos] = lastprivates.size(); 574 } else if (keyword == "linear") { 575 // fail if there was already another linear clause 576 if (segments[linearClausePos]) 577 return allowedOnce(parser, "linear", opName); 578 if (parseLinearClause(parser, linears, linearTypes, linearSteps)) 579 return failure(); 580 segments[linearClausePos] = linears.size(); 581 segments[linearStepPos] = linearSteps.size(); 582 } else if (keyword == "schedule") { 583 if (!schedule.empty()) 584 return allowedOnce(parser, "schedule", opName); 585 if (parseScheduleClause(parser, schedule, scheduleChunkSize)) 586 return failure(); 587 if (scheduleChunkSize) { 588 segments[scheduleClausePos] = 1; 589 } 590 } else if (keyword == "collapse") { 591 auto type = parser.getBuilder().getI64Type(); 592 mlir::IntegerAttr attr; 593 if (parser.parseLParen() || parser.parseAttribute(attr, type) || 594 parser.parseRParen()) 595 return failure(); 596 result.addAttribute("collapse_val", attr); 597 } else if (keyword == "nowait") { 598 auto attr = UnitAttr::get(parser.getContext()); 599 result.addAttribute("nowait", attr); 600 } else if (keyword == "ordered") { 601 mlir::IntegerAttr attr; 602 if (succeeded(parser.parseOptionalLParen())) { 603 auto type = parser.getBuilder().getI64Type(); 604 if (parser.parseAttribute(attr, type)) 605 return failure(); 606 if (parser.parseRParen()) 607 return failure(); 608 } else { 609 // Use 0 to represent no ordered parameter was specified 610 attr = parser.getBuilder().getI64IntegerAttr(0); 611 } 612 result.addAttribute("ordered_val", attr); 613 } else if (keyword == "order") { 614 StringRef order; 615 if (parser.parseLParen() || parser.parseKeyword(&order) || 616 parser.parseRParen()) 617 return failure(); 618 auto attr = parser.getBuilder().getStringAttr(order); 619 result.addAttribute("order", attr); 620 } else if (keyword == "inclusive") { 621 auto attr = UnitAttr::get(parser.getContext()); 622 result.addAttribute("inclusive", attr); 623 } else if (keyword == "reduction") { 624 if (segments[reductionVarPos]) 625 return allowedOnce(parser, "reduction", opName); 626 if (failed(parseReductionVarList(parser, reductionSymbols, reductionVars, 627 reductionVarTypes))) 628 return failure(); 629 segments[reductionVarPos] = reductionVars.size(); 630 } 631 } 632 633 if (segments[privateClausePos]) { 634 parser.resolveOperands(privates, privateTypes, privates[0].location, 635 result.operands); 636 } 637 638 if (segments[firstprivateClausePos]) { 639 parser.resolveOperands(firstprivates, firstprivateTypes, 640 firstprivates[0].location, result.operands); 641 } 642 643 if (segments[lastprivateClausePos]) { 644 parser.resolveOperands(lastprivates, lastprivateTypes, 645 lastprivates[0].location, result.operands); 646 } 647 648 if (segments[linearClausePos]) { 649 parser.resolveOperands(linears, linearTypes, linears[0].location, 650 result.operands); 651 auto linearStepType = parser.getBuilder().getI32Type(); 652 SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType); 653 parser.resolveOperands(linearSteps, linearStepTypes, 654 linearSteps[0].location, result.operands); 655 } 656 657 if (segments[reductionVarPos]) { 658 if (failed(parser.resolveOperands(reductionVars, reductionVarTypes, 659 parser.getNameLoc(), result.operands))) { 660 return failure(); 661 } 662 SmallVector<Attribute> reductions(reductionSymbols.begin(), 663 reductionSymbols.end()); 664 result.addAttribute("reductions", 665 parser.getBuilder().getArrayAttr(reductions)); 666 } 667 668 if (!schedule.empty()) { 669 schedule[0] = llvm::toUpper(schedule[0]); 670 auto attr = parser.getBuilder().getStringAttr(schedule); 671 result.addAttribute("schedule_val", attr); 672 if (scheduleChunkSize) { 673 auto chunkSizeType = parser.getBuilder().getI32Type(); 674 parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands); 675 } 676 } 677 678 result.addAttribute("operand_segment_sizes", 679 parser.getBuilder().getI32VectorAttr(segments)); 680 681 // Now parse the body. 682 Region *body = result.addRegion(); 683 SmallVector<Type> ivTypes(numIVs, loopVarType); 684 SmallVector<OpAsmParser::OperandType> blockArgs(ivs); 685 if (parser.parseRegion(*body, blockArgs, ivTypes)) 686 return failure(); 687 return success(); 688 } 689 690 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) { 691 auto args = op.getRegion().front().getArguments(); 692 p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound() 693 << ") to (" << op.upperBound() << ") step (" << op.step() << ")"; 694 695 // Print private, firstprivate, shared and copyin parameters 696 auto printDataVars = [&p](StringRef name, OperandRange vars) { 697 if (vars.empty()) 698 return; 699 700 p << " " << name << "("; 701 llvm::interleaveComma( 702 vars, p, [&](const Value &v) { p << v << " : " << v.getType(); }); 703 p << ")"; 704 }; 705 printDataVars("private", op.private_vars()); 706 printDataVars("firstprivate", op.firstprivate_vars()); 707 printDataVars("lastprivate", op.lastprivate_vars()); 708 709 auto linearVars = op.linear_vars(); 710 auto linearVarsSize = linearVars.size(); 711 if (linearVarsSize) { 712 p << " " 713 << "linear" 714 << "("; 715 for (unsigned i = 0; i < linearVarsSize; ++i) { 716 std::string separator = i == linearVarsSize - 1 ? ")" : ", "; 717 p << linearVars[i]; 718 if (op.linear_step_vars().size() > i) 719 p << " = " << op.linear_step_vars()[i]; 720 p << " : " << linearVars[i].getType() << separator; 721 } 722 } 723 724 if (auto sched = op.schedule_val()) { 725 auto schedLower = sched->lower(); 726 p << " schedule(" << schedLower; 727 if (auto chunk = op.schedule_chunk_var()) { 728 p << " = " << chunk; 729 } 730 p << ")"; 731 } 732 733 if (auto collapse = op.collapse_val()) 734 p << " collapse(" << collapse << ")"; 735 736 if (op.nowait()) 737 p << " nowait"; 738 739 if (auto ordered = op.ordered_val()) { 740 p << " ordered(" << ordered << ")"; 741 } 742 743 if (!op.reduction_vars().empty()) { 744 p << " reduction("; 745 for (unsigned i = 0, e = op.getNumReductionVars(); i < e; ++i) { 746 if (i != 0) 747 p << ", "; 748 p << (*op.reductions())[i] << " -> " << op.reduction_vars()[i] << " : " 749 << op.reduction_vars()[i].getType(); 750 } 751 p << ")"; 752 } 753 754 if (op.inclusive()) { 755 p << " inclusive"; 756 } 757 758 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 759 } 760 761 //===----------------------------------------------------------------------===// 762 // ReductionOp 763 //===----------------------------------------------------------------------===// 764 765 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, 766 Region ®ion) { 767 if (parser.parseOptionalKeyword("atomic")) 768 return success(); 769 return parser.parseRegion(region); 770 } 771 772 static void printAtomicReductionRegion(OpAsmPrinter &printer, 773 ReductionDeclareOp op, Region ®ion) { 774 if (region.empty()) 775 return; 776 printer << "atomic "; 777 printer.printRegion(region); 778 } 779 780 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) { 781 if (op.initializerRegion().empty()) 782 return op.emitOpError() << "expects non-empty initializer region"; 783 Block &initializerEntryBlock = op.initializerRegion().front(); 784 if (initializerEntryBlock.getNumArguments() != 1 || 785 initializerEntryBlock.getArgument(0).getType() != op.type()) { 786 return op.emitOpError() << "expects initializer region with one argument " 787 "of the reduction type"; 788 } 789 790 for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) { 791 if (yieldOp.results().size() != 1 || 792 yieldOp.results().getTypes()[0] != op.type()) 793 return op.emitOpError() << "expects initializer region to yield a value " 794 "of the reduction type"; 795 } 796 797 if (op.reductionRegion().empty()) 798 return op.emitOpError() << "expects non-empty reduction region"; 799 Block &reductionEntryBlock = op.reductionRegion().front(); 800 if (reductionEntryBlock.getNumArguments() != 2 || 801 reductionEntryBlock.getArgumentTypes()[0] != 802 reductionEntryBlock.getArgumentTypes()[1] || 803 reductionEntryBlock.getArgumentTypes()[0] != op.type()) 804 return op.emitOpError() << "expects reduction region with two arguments of " 805 "the reduction type"; 806 for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) { 807 if (yieldOp.results().size() != 1 || 808 yieldOp.results().getTypes()[0] != op.type()) 809 return op.emitOpError() << "expects reduction region to yield a value " 810 "of the reduction type"; 811 } 812 813 if (op.atomicReductionRegion().empty()) 814 return success(); 815 816 Block &atomicReductionEntryBlock = op.atomicReductionRegion().front(); 817 if (atomicReductionEntryBlock.getNumArguments() != 2 || 818 atomicReductionEntryBlock.getArgumentTypes()[0] != 819 atomicReductionEntryBlock.getArgumentTypes()[1]) 820 return op.emitOpError() << "expects atomic reduction region with two " 821 "arguments of the same type"; 822 auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0] 823 .dyn_cast<PointerLikeType>(); 824 if (!ptrType || ptrType.getElementType() != op.type()) 825 return op.emitOpError() << "expects atomic reduction region arguments to " 826 "be accumulators containing the reduction type"; 827 return success(); 828 } 829 830 static LogicalResult verifyReductionOp(ReductionOp op) { 831 // TODO: generalize this to an op interface when there is more than one op 832 // that supports reductions. 833 auto container = op->getParentOfType<WsLoopOp>(); 834 for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) 835 if (container.reduction_vars()[i] == op.accumulator()) 836 return success(); 837 838 return op.emitOpError() << "the accumulator is not used by the parent"; 839 } 840 841 //===----------------------------------------------------------------------===// 842 // WsLoopOp 843 //===----------------------------------------------------------------------===// 844 845 void WsLoopOp::build(OpBuilder &builder, OperationState &state, 846 ValueRange lowerBound, ValueRange upperBound, 847 ValueRange step, ArrayRef<NamedAttribute> attributes) { 848 build(builder, state, TypeRange(), lowerBound, upperBound, step, 849 /*private_vars=*/ValueRange(), 850 /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(), 851 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), 852 /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr, 853 /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr, 854 /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr, 855 /*inclusive=*/nullptr, /*buildBody=*/false); 856 state.addAttributes(attributes); 857 } 858 859 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes, 860 ValueRange operands, ArrayRef<NamedAttribute> attributes) { 861 state.addOperands(operands); 862 state.addAttributes(attributes); 863 (void)state.addRegion(); 864 assert(resultTypes.empty() && "mismatched number of return types"); 865 state.addTypes(resultTypes); 866 } 867 868 void WsLoopOp::build(OpBuilder &builder, OperationState &result, 869 TypeRange typeRange, ValueRange lowerBounds, 870 ValueRange upperBounds, ValueRange steps, 871 ValueRange privateVars, ValueRange firstprivateVars, 872 ValueRange lastprivateVars, ValueRange linearVars, 873 ValueRange linearStepVars, ValueRange reductionVars, 874 StringAttr scheduleVal, Value scheduleChunkVar, 875 IntegerAttr collapseVal, UnitAttr nowait, 876 IntegerAttr orderedVal, StringAttr orderVal, 877 UnitAttr inclusive, bool buildBody) { 878 result.addOperands(lowerBounds); 879 result.addOperands(upperBounds); 880 result.addOperands(steps); 881 result.addOperands(privateVars); 882 result.addOperands(firstprivateVars); 883 result.addOperands(linearVars); 884 result.addOperands(linearStepVars); 885 if (scheduleChunkVar) 886 result.addOperands(scheduleChunkVar); 887 888 if (scheduleVal) 889 result.addAttribute("schedule_val", scheduleVal); 890 if (collapseVal) 891 result.addAttribute("collapse_val", collapseVal); 892 if (nowait) 893 result.addAttribute("nowait", nowait); 894 if (orderedVal) 895 result.addAttribute("ordered_val", orderedVal); 896 if (orderVal) 897 result.addAttribute("order", orderVal); 898 if (inclusive) 899 result.addAttribute("inclusive", inclusive); 900 result.addAttribute( 901 WsLoopOp::getOperandSegmentSizeAttr(), 902 builder.getI32VectorAttr( 903 {static_cast<int32_t>(lowerBounds.size()), 904 static_cast<int32_t>(upperBounds.size()), 905 static_cast<int32_t>(steps.size()), 906 static_cast<int32_t>(privateVars.size()), 907 static_cast<int32_t>(firstprivateVars.size()), 908 static_cast<int32_t>(lastprivateVars.size()), 909 static_cast<int32_t>(linearVars.size()), 910 static_cast<int32_t>(linearStepVars.size()), 911 static_cast<int32_t>(reductionVars.size()), 912 static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)})); 913 914 Region *bodyRegion = result.addRegion(); 915 if (buildBody) { 916 OpBuilder::InsertionGuard guard(builder); 917 unsigned numIVs = steps.size(); 918 SmallVector<Type, 8> argTypes(numIVs, steps.getType().front()); 919 builder.createBlock(bodyRegion, {}, argTypes); 920 } 921 } 922 923 static LogicalResult verifyWsLoopOp(WsLoopOp op) { 924 if (op.getNumReductionVars() != 0) { 925 if (!op.reductions() || 926 op.reductions()->size() != op.getNumReductionVars()) { 927 return op.emitOpError() << "expected as many reduction symbol references " 928 "as reduction variables"; 929 } 930 } else { 931 if (op.reductions()) 932 return op.emitOpError() << "unexpected reduction symbol references"; 933 return success(); 934 } 935 936 DenseSet<Value> accumulators; 937 for (auto args : llvm::zip(op.reduction_vars(), *op.reductions())) { 938 Value accum = std::get<0>(args); 939 if (!accumulators.insert(accum).second) { 940 return op.emitOpError() << "accumulator variable used more than once"; 941 } 942 Type varType = accum.getType().cast<PointerLikeType>(); 943 auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>(); 944 auto decl = 945 SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef); 946 if (!decl) { 947 return op.emitOpError() << "expected symbol reference " << symbolRef 948 << " to point to a reduction declaration"; 949 } 950 951 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) { 952 return op.emitOpError() 953 << "expected accumulator (" << varType 954 << ") to be the same type as reduction declaration (" 955 << decl.getAccumulatorType() << ")"; 956 } 957 } 958 959 return success(); 960 } 961 962 //===----------------------------------------------------------------------===// 963 // Parser, printer and verifier for Synchronization Hint (2.17.12) 964 //===----------------------------------------------------------------------===// 965 966 /// Parses a Synchronization Hint clause. The value of hint is an integer 967 /// which is a combination of different hints from `omp_sync_hint_t`. 968 /// 969 /// hint-clause = `hint` `(` hint-value `)` 970 static ParseResult parseSynchronizationHint(OpAsmParser &parser, 971 IntegerAttr &hintAttr) { 972 if (failed(parser.parseOptionalKeyword("hint"))) { 973 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0); 974 return success(); 975 } 976 977 if (failed(parser.parseLParen())) 978 return failure(); 979 StringRef hintKeyword; 980 int64_t hint = 0; 981 do { 982 if (failed(parser.parseKeyword(&hintKeyword))) 983 return failure(); 984 if (hintKeyword == "uncontended") 985 hint |= 1; 986 else if (hintKeyword == "contended") 987 hint |= 2; 988 else if (hintKeyword == "nonspeculative") 989 hint |= 4; 990 else if (hintKeyword == "speculative") 991 hint |= 8; 992 else 993 return parser.emitError(parser.getCurrentLocation()) 994 << hintKeyword << " is not a valid hint"; 995 } while (succeeded(parser.parseOptionalComma())); 996 if (failed(parser.parseRParen())) 997 return failure(); 998 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint); 999 return success(); 1000 } 1001 1002 /// Prints a Synchronization Hint clause 1003 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, 1004 IntegerAttr hintAttr) { 1005 int64_t hint = hintAttr.getInt(); 1006 1007 if (hint == 0) 1008 return; 1009 1010 // Helper function to get n-th bit from the right end of `value` 1011 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 1012 1013 bool uncontended = bitn(hint, 0); 1014 bool contended = bitn(hint, 1); 1015 bool nonspeculative = bitn(hint, 2); 1016 bool speculative = bitn(hint, 3); 1017 1018 SmallVector<StringRef> hints; 1019 if (uncontended) 1020 hints.push_back("uncontended"); 1021 if (contended) 1022 hints.push_back("contended"); 1023 if (nonspeculative) 1024 hints.push_back("nonspeculative"); 1025 if (speculative) 1026 hints.push_back("speculative"); 1027 1028 p << "hint("; 1029 llvm::interleaveComma(hints, p); 1030 p << ")"; 1031 } 1032 1033 /// Verifies a synchronization hint clause 1034 static LogicalResult verifySynchronizationHint(Operation *op, int32_t hint) { 1035 1036 // Helper function to get n-th bit from the right end of `value` 1037 auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; 1038 1039 bool uncontended = bitn(hint, 0); 1040 bool contended = bitn(hint, 1); 1041 bool nonspeculative = bitn(hint, 2); 1042 bool speculative = bitn(hint, 3); 1043 1044 if (uncontended && contended) 1045 return op->emitOpError() << "the hints omp_sync_hint_uncontended and " 1046 "omp_sync_hint_contended cannot be combined"; 1047 if (nonspeculative && speculative) 1048 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and " 1049 "omp_sync_hint_speculative cannot be combined."; 1050 return success(); 1051 } 1052 1053 //===----------------------------------------------------------------------===// 1054 // Verifier for critical construct (2.17.1) 1055 //===----------------------------------------------------------------------===// 1056 1057 static LogicalResult verifyCriticalOp(CriticalOp op) { 1058 1059 if (failed(verifySynchronizationHint(op, op.hint()))) { 1060 return failure(); 1061 } 1062 if (!op.name().hasValue() && (op.hint() != 0)) 1063 return op.emitOpError() << "must specify a name unless the effect is as if " 1064 "no hint is specified"; 1065 1066 if (op.nameAttr()) { 1067 auto symbolRef = op.nameAttr().cast<SymbolRefAttr>(); 1068 auto decl = 1069 SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef); 1070 if (!decl) { 1071 return op.emitOpError() << "expected symbol reference " << symbolRef 1072 << " to point to a critical declaration"; 1073 } 1074 } 1075 1076 return success(); 1077 } 1078 1079 #define GET_OP_CLASSES 1080 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 1081