1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the OpenMP dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/OpImplementation.h" 17 #include "mlir/IR/OperationSupport.h" 18 19 #include "llvm/ADT/SmallString.h" 20 #include "llvm/ADT/StringExtras.h" 21 #include "llvm/ADT/StringRef.h" 22 #include "llvm/ADT/StringSwitch.h" 23 #include <cstddef> 24 25 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" 26 27 using namespace mlir; 28 using namespace mlir::omp; 29 30 void OpenMPDialect::initialize() { 31 addOperations< 32 #define GET_OP_LIST 33 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 34 >(); 35 } 36 37 //===----------------------------------------------------------------------===// 38 // ParallelOp 39 //===----------------------------------------------------------------------===// 40 41 void ParallelOp::build(OpBuilder &builder, OperationState &state, 42 ArrayRef<NamedAttribute> attributes) { 43 ParallelOp::build( 44 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, 45 /*default_val=*/nullptr, /*private_vars=*/ValueRange(), 46 /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(), 47 /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(), 48 /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr); 49 state.addAttributes(attributes); 50 } 51 52 /// Parse a list of operands with types. 53 /// 54 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)` 55 /// ssa-id-and-type-list ::= ssa-id-and-type | 56 /// ssa-id-and-type `,` ssa-id-and-type-list 57 /// ssa-id-and-type ::= ssa-id `:` type 58 static ParseResult 59 parseOperandAndTypeList(OpAsmParser &parser, 60 SmallVectorImpl<OpAsmParser::OperandType> &operands, 61 SmallVectorImpl<Type> &types) { 62 if (parser.parseLParen()) 63 return failure(); 64 65 do { 66 OpAsmParser::OperandType operand; 67 Type type; 68 if (parser.parseOperand(operand) || parser.parseColonType(type)) 69 return failure(); 70 operands.push_back(operand); 71 types.push_back(type); 72 } while (succeeded(parser.parseOptionalComma())); 73 74 if (parser.parseRParen()) 75 return failure(); 76 77 return success(); 78 } 79 80 /// Parse an allocate clause with allocators and a list of operands with types. 81 /// 82 /// operand-and-type-list ::= `(` allocate-operand-list `)` 83 /// allocate-operand-list :: = allocate-operand | 84 /// allocator-operand `,` allocate-operand-list 85 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 86 /// ssa-id-and-type ::= ssa-id `:` type 87 static ParseResult parseAllocateAndAllocator( 88 OpAsmParser &parser, 89 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate, 90 SmallVectorImpl<Type> &typesAllocate, 91 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator, 92 SmallVectorImpl<Type> &typesAllocator) { 93 if (parser.parseLParen()) 94 return failure(); 95 96 do { 97 OpAsmParser::OperandType operand; 98 Type type; 99 100 if (parser.parseOperand(operand) || parser.parseColonType(type)) 101 return failure(); 102 operandsAllocator.push_back(operand); 103 typesAllocator.push_back(type); 104 if (parser.parseArrow()) 105 return failure(); 106 if (parser.parseOperand(operand) || parser.parseColonType(type)) 107 return failure(); 108 109 operandsAllocate.push_back(operand); 110 typesAllocate.push_back(type); 111 } while (succeeded(parser.parseOptionalComma())); 112 113 if (parser.parseRParen()) 114 return failure(); 115 116 return success(); 117 } 118 119 static LogicalResult verifyParallelOp(ParallelOp op) { 120 if (op.allocate_vars().size() != op.allocators_vars().size()) 121 return op.emitError( 122 "expected equal sizes for allocate and allocator variables"); 123 return success(); 124 } 125 126 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) { 127 p << "omp.parallel"; 128 129 if (auto ifCond = op.if_expr_var()) 130 p << " if(" << ifCond << " : " << ifCond.getType() << ")"; 131 132 if (auto threads = op.num_threads_var()) 133 p << " num_threads(" << threads << " : " << threads.getType() << ")"; 134 135 // Print private, firstprivate, shared and copyin parameters 136 auto printDataVars = [&p](StringRef name, OperandRange vars) { 137 if (vars.size()) { 138 p << " " << name << "("; 139 for (unsigned i = 0; i < vars.size(); ++i) { 140 std::string separator = i == vars.size() - 1 ? ")" : ", "; 141 p << vars[i] << " : " << vars[i].getType() << separator; 142 } 143 } 144 }; 145 146 // Print allocator and allocate parameters 147 auto printAllocateAndAllocator = [&p](OperandRange varsAllocate, 148 OperandRange varsAllocator) { 149 if (varsAllocate.empty()) 150 return; 151 152 p << " allocate("; 153 for (unsigned i = 0; i < varsAllocate.size(); ++i) { 154 std::string separator = i == varsAllocate.size() - 1 ? ")" : ", "; 155 p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> "; 156 p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator; 157 } 158 }; 159 160 printDataVars("private", op.private_vars()); 161 printDataVars("firstprivate", op.firstprivate_vars()); 162 printDataVars("shared", op.shared_vars()); 163 printDataVars("copyin", op.copyin_vars()); 164 printAllocateAndAllocator(op.allocate_vars(), op.allocators_vars()); 165 166 if (auto def = op.default_val()) 167 p << " default(" << def->drop_front(3) << ")"; 168 169 if (auto bind = op.proc_bind_val()) 170 p << " proc_bind(" << bind << ")"; 171 172 p.printRegion(op.getRegion()); 173 } 174 175 /// Emit an error if the same clause is present more than once on an operation. 176 static ParseResult allowedOnce(OpAsmParser &parser, StringRef clause, 177 StringRef operation) { 178 return parser.emitError(parser.getNameLoc()) 179 << " at most one " << clause << " clause can appear on the " 180 << operation << " operation"; 181 } 182 183 /// Parses a parallel operation. 184 /// 185 /// operation ::= `omp.parallel` clause-list 186 /// clause-list ::= clause | clause clause-list 187 /// clause ::= if | numThreads | private | firstprivate | shared | copyin | 188 /// default | procBind 189 /// if ::= `if` `(` ssa-id `)` 190 /// numThreads ::= `num_threads` `(` ssa-id-and-type `)` 191 /// private ::= `private` operand-and-type-list 192 /// firstprivate ::= `firstprivate` operand-and-type-list 193 /// shared ::= `shared` operand-and-type-list 194 /// copyin ::= `copyin` operand-and-type-list 195 /// allocate ::= `allocate` operand-and-type `->` operand-and-type-list 196 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) 197 /// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` 198 /// 199 /// Note that each clause can only appear once in the clase-list. 200 static ParseResult parseParallelOp(OpAsmParser &parser, 201 OperationState &result) { 202 std::pair<OpAsmParser::OperandType, Type> ifCond; 203 std::pair<OpAsmParser::OperandType, Type> numThreads; 204 SmallVector<OpAsmParser::OperandType, 4> privates; 205 SmallVector<Type, 4> privateTypes; 206 SmallVector<OpAsmParser::OperandType, 4> firstprivates; 207 SmallVector<Type, 4> firstprivateTypes; 208 SmallVector<OpAsmParser::OperandType, 4> shareds; 209 SmallVector<Type, 4> sharedTypes; 210 SmallVector<OpAsmParser::OperandType, 4> copyins; 211 SmallVector<Type, 4> copyinTypes; 212 SmallVector<OpAsmParser::OperandType, 4> allocates; 213 SmallVector<Type, 4> allocateTypes; 214 SmallVector<OpAsmParser::OperandType, 4> allocators; 215 SmallVector<Type, 4> allocatorTypes; 216 std::array<int, 8> segments{0, 0, 0, 0, 0, 0, 0, 0}; 217 StringRef keyword; 218 bool defaultVal = false; 219 bool procBind = false; 220 221 const int ifClausePos = 0; 222 const int numThreadsClausePos = 1; 223 const int privateClausePos = 2; 224 const int firstprivateClausePos = 3; 225 const int sharedClausePos = 4; 226 const int copyinClausePos = 5; 227 const int allocateClausePos = 6; 228 const int allocatorPos = 7; 229 const StringRef opName = result.name.getStringRef(); 230 231 while (succeeded(parser.parseOptionalKeyword(&keyword))) { 232 if (keyword == "if") { 233 // Fail if there was already another if condition. 234 if (segments[ifClausePos]) 235 return allowedOnce(parser, "if", opName); 236 if (parser.parseLParen() || parser.parseOperand(ifCond.first) || 237 parser.parseColonType(ifCond.second) || parser.parseRParen()) 238 return failure(); 239 segments[ifClausePos] = 1; 240 } else if (keyword == "num_threads") { 241 // Fail if there was already another num_threads clause. 242 if (segments[numThreadsClausePos]) 243 return allowedOnce(parser, "num_threads", opName); 244 if (parser.parseLParen() || parser.parseOperand(numThreads.first) || 245 parser.parseColonType(numThreads.second) || parser.parseRParen()) 246 return failure(); 247 segments[numThreadsClausePos] = 1; 248 } else if (keyword == "private") { 249 // Fail if there was already another private clause. 250 if (segments[privateClausePos]) 251 return allowedOnce(parser, "private", opName); 252 if (parseOperandAndTypeList(parser, privates, privateTypes)) 253 return failure(); 254 segments[privateClausePos] = privates.size(); 255 } else if (keyword == "firstprivate") { 256 // Fail if there was already another firstprivate clause. 257 if (segments[firstprivateClausePos]) 258 return allowedOnce(parser, "firstprivate", opName); 259 if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) 260 return failure(); 261 segments[firstprivateClausePos] = firstprivates.size(); 262 } else if (keyword == "shared") { 263 // Fail if there was already another shared clause. 264 if (segments[sharedClausePos]) 265 return allowedOnce(parser, "shared", opName); 266 if (parseOperandAndTypeList(parser, shareds, sharedTypes)) 267 return failure(); 268 segments[sharedClausePos] = shareds.size(); 269 } else if (keyword == "copyin") { 270 // Fail if there was already another copyin clause. 271 if (segments[copyinClausePos]) 272 return allowedOnce(parser, "copyin", opName); 273 if (parseOperandAndTypeList(parser, copyins, copyinTypes)) 274 return failure(); 275 segments[copyinClausePos] = copyins.size(); 276 } else if (keyword == "allocate") { 277 // Fail if there was already another allocate clause. 278 if (segments[allocateClausePos]) 279 return allowedOnce(parser, "allocate", opName); 280 if (parseAllocateAndAllocator(parser, allocates, allocateTypes, 281 allocators, allocatorTypes)) 282 return failure(); 283 segments[allocateClausePos] = allocates.size(); 284 segments[allocatorPos] = allocators.size(); 285 } else if (keyword == "default") { 286 // Fail if there was already another default clause. 287 if (defaultVal) 288 return allowedOnce(parser, "default", opName); 289 defaultVal = true; 290 StringRef defval; 291 if (parser.parseLParen() || parser.parseKeyword(&defval) || 292 parser.parseRParen()) 293 return failure(); 294 SmallString<16> attrval; 295 // The def prefix is required for the attribute as "private" is a keyword 296 // in C++. 297 attrval += "def"; 298 attrval += defval; 299 auto attr = parser.getBuilder().getStringAttr(attrval); 300 result.addAttribute("default_val", attr); 301 } else if (keyword == "proc_bind") { 302 // Fail if there was already another proc_bind clause. 303 if (procBind) 304 return allowedOnce(parser, "proc_bind", opName); 305 procBind = true; 306 StringRef bind; 307 if (parser.parseLParen() || parser.parseKeyword(&bind) || 308 parser.parseRParen()) 309 return failure(); 310 auto attr = parser.getBuilder().getStringAttr(bind); 311 result.addAttribute("proc_bind_val", attr); 312 } else { 313 return parser.emitError(parser.getNameLoc()) 314 << keyword << " is not a valid clause for the " << opName 315 << " operation"; 316 } 317 } 318 319 // Add if parameter. 320 if (segments[ifClausePos] && 321 parser.resolveOperand(ifCond.first, ifCond.second, result.operands)) 322 return failure(); 323 324 // Add num_threads parameter. 325 if (segments[numThreadsClausePos] && 326 parser.resolveOperand(numThreads.first, numThreads.second, 327 result.operands)) 328 return failure(); 329 330 // Add private parameters. 331 if (segments[privateClausePos] && 332 parser.resolveOperands(privates, privateTypes, privates[0].location, 333 result.operands)) 334 return failure(); 335 336 // Add firstprivate parameters. 337 if (segments[firstprivateClausePos] && 338 parser.resolveOperands(firstprivates, firstprivateTypes, 339 firstprivates[0].location, result.operands)) 340 return failure(); 341 342 // Add shared parameters. 343 if (segments[sharedClausePos] && 344 parser.resolveOperands(shareds, sharedTypes, shareds[0].location, 345 result.operands)) 346 return failure(); 347 348 // Add copyin parameters. 349 if (segments[copyinClausePos] && 350 parser.resolveOperands(copyins, copyinTypes, copyins[0].location, 351 result.operands)) 352 return failure(); 353 354 // Add allocate parameters. 355 if (segments[allocateClausePos] && 356 parser.resolveOperands(allocates, allocateTypes, allocates[0].location, 357 result.operands)) 358 return failure(); 359 360 // Add allocator parameters. 361 if (segments[allocatorPos] && 362 parser.resolveOperands(allocators, allocatorTypes, allocators[0].location, 363 result.operands)) 364 return failure(); 365 366 result.addAttribute("operand_segment_sizes", 367 parser.getBuilder().getI32VectorAttr(segments)); 368 369 Region *body = result.addRegion(); 370 SmallVector<OpAsmParser::OperandType, 4> regionArgs; 371 SmallVector<Type, 4> regionArgTypes; 372 if (parser.parseRegion(*body, regionArgs, regionArgTypes)) 373 return failure(); 374 return success(); 375 } 376 377 /// linear ::= `linear` `(` linear-list `)` 378 /// linear-list := linear-val | linear-val linear-list 379 /// linear-val := ssa-id-and-type `=` ssa-id-and-type 380 static ParseResult 381 parseLinearClause(OpAsmParser &parser, 382 SmallVectorImpl<OpAsmParser::OperandType> &vars, 383 SmallVectorImpl<Type> &types, 384 SmallVectorImpl<OpAsmParser::OperandType> &stepVars) { 385 if (parser.parseLParen()) 386 return failure(); 387 388 do { 389 OpAsmParser::OperandType var; 390 Type type; 391 OpAsmParser::OperandType stepVar; 392 if (parser.parseOperand(var) || parser.parseEqual() || 393 parser.parseOperand(stepVar) || parser.parseColonType(type)) 394 return failure(); 395 396 vars.push_back(var); 397 types.push_back(type); 398 stepVars.push_back(stepVar); 399 } while (succeeded(parser.parseOptionalComma())); 400 401 if (parser.parseRParen()) 402 return failure(); 403 404 return success(); 405 } 406 407 /// schedule ::= `schedule` `(` sched-list `)` 408 /// sched-list ::= sched-val | sched-val sched-list 409 /// sched-val ::= sched-with-chunk | sched-wo-chunk 410 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 411 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided` 412 /// sched-wo-chunk ::= `auto` | `runtime` 413 static ParseResult 414 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, 415 Optional<OpAsmParser::OperandType> &chunkSize) { 416 if (parser.parseLParen()) 417 return failure(); 418 419 StringRef keyword; 420 if (parser.parseKeyword(&keyword)) 421 return failure(); 422 423 schedule = keyword; 424 if (keyword == "static" || keyword == "dynamic" || keyword == "guided") { 425 if (succeeded(parser.parseOptionalEqual())) { 426 chunkSize = OpAsmParser::OperandType{}; 427 if (parser.parseOperand(*chunkSize)) 428 return failure(); 429 } else { 430 chunkSize = llvm::NoneType::None; 431 } 432 } else if (keyword == "auto" || keyword == "runtime") { 433 chunkSize = llvm::NoneType::None; 434 } else { 435 return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; 436 } 437 438 if (parser.parseRParen()) 439 return failure(); 440 441 return success(); 442 } 443 444 /// Parses an OpenMP Workshare Loop operation 445 /// 446 /// operation ::= `omp.wsloop` loop-control clause-list 447 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds 448 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps 449 /// steps := `step` `(`ssa-id-list`)` 450 /// clause-list ::= clause | empty | clause-list 451 /// clause ::= private | firstprivate | lastprivate | linear | schedule | 452 // collapse | nowait | ordered | order | inclusive 453 /// private ::= `private` `(` ssa-id-and-type-list `)` 454 /// firstprivate ::= `firstprivate` `(` ssa-id-and-type-list `)` 455 /// lastprivate ::= `lastprivate` `(` ssa-id-and-type-list `)` 456 /// linear ::= `linear` `(` linear-list `)` 457 /// schedule ::= `schedule` `(` sched-list `)` 458 /// collapse ::= `collapse` `(` ssa-id-and-type `)` 459 /// nowait ::= `nowait` 460 /// ordered ::= `ordered` `(` ssa-id-and-type `)` 461 /// order ::= `order` `(` `concurrent` `)` 462 /// inclusive ::= `inclusive` 463 /// 464 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) { 465 Type loopVarType; 466 int numIVs; 467 468 // Parse an opening `(` followed by induction variables followed by `)` 469 SmallVector<OpAsmParser::OperandType> ivs; 470 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, 471 OpAsmParser::Delimiter::Paren)) 472 return failure(); 473 474 numIVs = static_cast<int>(ivs.size()); 475 476 if (parser.parseColonType(loopVarType)) 477 return failure(); 478 479 // Parse loop bounds. 480 SmallVector<OpAsmParser::OperandType> lower; 481 if (parser.parseEqual() || 482 parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) || 483 parser.resolveOperands(lower, loopVarType, result.operands)) 484 return failure(); 485 486 SmallVector<OpAsmParser::OperandType> upper; 487 if (parser.parseKeyword("to") || 488 parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) || 489 parser.resolveOperands(upper, loopVarType, result.operands)) 490 return failure(); 491 492 // Parse step values. 493 SmallVector<OpAsmParser::OperandType> steps; 494 if (parser.parseKeyword("step") || 495 parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) || 496 parser.resolveOperands(steps, loopVarType, result.operands)) 497 return failure(); 498 499 SmallVector<OpAsmParser::OperandType> privates; 500 SmallVector<Type> privateTypes; 501 SmallVector<OpAsmParser::OperandType> firstprivates; 502 SmallVector<Type> firstprivateTypes; 503 SmallVector<OpAsmParser::OperandType> lastprivates; 504 SmallVector<Type> lastprivateTypes; 505 SmallVector<OpAsmParser::OperandType> linears; 506 SmallVector<Type> linearTypes; 507 SmallVector<OpAsmParser::OperandType> linearSteps; 508 SmallString<8> schedule; 509 Optional<OpAsmParser::OperandType> scheduleChunkSize; 510 std::array<int, 9> segments{numIVs, numIVs, numIVs, 0, 0, 0, 0, 0, 0}; 511 512 const StringRef opName = result.name.getStringRef(); 513 StringRef keyword; 514 515 enum SegmentPos { 516 lbPos = 0, 517 ubPos, 518 stepPos, 519 privateClausePos, 520 firstprivateClausePos, 521 lastprivateClausePos, 522 linearClausePos, 523 linearStepPos, 524 scheduleClausePos, 525 }; 526 527 while (succeeded(parser.parseOptionalKeyword(&keyword))) { 528 if (keyword == "private") { 529 if (segments[privateClausePos]) 530 return allowedOnce(parser, "private", opName); 531 if (parseOperandAndTypeList(parser, privates, privateTypes)) 532 return failure(); 533 segments[privateClausePos] = privates.size(); 534 } else if (keyword == "firstprivate") { 535 // fail if there was already another firstprivate clause 536 if (segments[firstprivateClausePos]) 537 return allowedOnce(parser, "firstprivate", opName); 538 if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) 539 return failure(); 540 segments[firstprivateClausePos] = firstprivates.size(); 541 } else if (keyword == "lastprivate") { 542 // fail if there was already another shared clause 543 if (segments[lastprivateClausePos]) 544 return allowedOnce(parser, "lastprivate", opName); 545 if (parseOperandAndTypeList(parser, lastprivates, lastprivateTypes)) 546 return failure(); 547 segments[lastprivateClausePos] = lastprivates.size(); 548 } else if (keyword == "linear") { 549 // fail if there was already another linear clause 550 if (segments[linearClausePos]) 551 return allowedOnce(parser, "linear", opName); 552 if (parseLinearClause(parser, linears, linearTypes, linearSteps)) 553 return failure(); 554 segments[linearClausePos] = linears.size(); 555 segments[linearStepPos] = linearSteps.size(); 556 } else if (keyword == "schedule") { 557 if (!schedule.empty()) 558 return allowedOnce(parser, "schedule", opName); 559 if (parseScheduleClause(parser, schedule, scheduleChunkSize)) 560 return failure(); 561 if (scheduleChunkSize) { 562 segments[scheduleClausePos] = 1; 563 } 564 } else if (keyword == "collapse") { 565 auto type = parser.getBuilder().getI64Type(); 566 mlir::IntegerAttr attr; 567 if (parser.parseLParen() || parser.parseAttribute(attr, type) || 568 parser.parseRParen()) 569 return failure(); 570 result.addAttribute("collapse_val", attr); 571 } else if (keyword == "nowait") { 572 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 573 result.addAttribute("nowait", attr); 574 } else if (keyword == "ordered") { 575 mlir::IntegerAttr attr; 576 if (succeeded(parser.parseOptionalLParen())) { 577 auto type = parser.getBuilder().getI64Type(); 578 if (parser.parseAttribute(attr, type)) 579 return failure(); 580 if (parser.parseRParen()) 581 return failure(); 582 } else { 583 // Use 0 to represent no ordered parameter was specified 584 attr = parser.getBuilder().getI64IntegerAttr(0); 585 } 586 result.addAttribute("ordered_val", attr); 587 } else if (keyword == "order") { 588 StringRef order; 589 if (parser.parseLParen() || parser.parseKeyword(&order) || 590 parser.parseRParen()) 591 return failure(); 592 auto attr = parser.getBuilder().getStringAttr(order); 593 result.addAttribute("order", attr); 594 } else if (keyword == "inclusive") { 595 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 596 result.addAttribute("inclusive", attr); 597 } 598 } 599 600 if (segments[privateClausePos]) { 601 parser.resolveOperands(privates, privateTypes, privates[0].location, 602 result.operands); 603 } 604 605 if (segments[firstprivateClausePos]) { 606 parser.resolveOperands(firstprivates, firstprivateTypes, 607 firstprivates[0].location, result.operands); 608 } 609 610 if (segments[lastprivateClausePos]) { 611 parser.resolveOperands(lastprivates, lastprivateTypes, 612 lastprivates[0].location, result.operands); 613 } 614 615 if (segments[linearClausePos]) { 616 parser.resolveOperands(linears, linearTypes, linears[0].location, 617 result.operands); 618 auto linearStepType = parser.getBuilder().getI32Type(); 619 SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType); 620 parser.resolveOperands(linearSteps, linearStepTypes, 621 linearSteps[0].location, result.operands); 622 } 623 624 if (!schedule.empty()) { 625 schedule[0] = llvm::toUpper(schedule[0]); 626 auto attr = parser.getBuilder().getStringAttr(schedule); 627 result.addAttribute("schedule_val", attr); 628 if (scheduleChunkSize) { 629 auto chunkSizeType = parser.getBuilder().getI32Type(); 630 parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands); 631 } 632 } 633 634 result.addAttribute("operand_segment_sizes", 635 parser.getBuilder().getI32VectorAttr(segments)); 636 637 // Now parse the body. 638 Region *body = result.addRegion(); 639 SmallVector<Type> ivTypes(numIVs, loopVarType); 640 if (parser.parseRegion(*body, ivs, ivTypes)) 641 return failure(); 642 return success(); 643 } 644 645 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) { 646 auto args = op.getRegion().front().getArguments(); 647 p << op.getOperationName() << " (" << args << ") : " << args[0].getType() 648 << " = (" << op.lowerBound() << ") to (" << op.upperBound() << ") step (" 649 << op.step() << ")"; 650 651 // Print private, firstprivate, shared and copyin parameters 652 auto printDataVars = [&p](StringRef name, OperandRange vars) { 653 if (vars.empty()) 654 return; 655 656 p << " " << name << "("; 657 llvm::interleaveComma( 658 vars, p, [&](const Value &v) { p << v << " : " << v.getType(); }); 659 p << ")"; 660 }; 661 printDataVars("private", op.private_vars()); 662 printDataVars("firstprivate", op.firstprivate_vars()); 663 printDataVars("lastprivate", op.lastprivate_vars()); 664 665 auto linearVars = op.linear_vars(); 666 auto linearVarsSize = linearVars.size(); 667 if (linearVarsSize) { 668 p << " " 669 << "linear" 670 << "("; 671 for (unsigned i = 0; i < linearVarsSize; ++i) { 672 std::string separator = i == linearVarsSize - 1 ? ")" : ", "; 673 p << linearVars[i]; 674 if (op.linear_step_vars().size() > i) 675 p << " = " << op.linear_step_vars()[i]; 676 p << " : " << linearVars[i].getType() << separator; 677 } 678 } 679 680 if (auto sched = op.schedule_val()) { 681 auto schedLower = sched->lower(); 682 p << " schedule(" << schedLower; 683 if (auto chunk = op.schedule_chunk_var()) { 684 p << " = " << chunk; 685 } 686 p << ")"; 687 } 688 689 if (auto collapse = op.collapse_val()) 690 p << " collapse(" << collapse << ")"; 691 692 if (op.nowait()) 693 p << " nowait"; 694 695 if (auto ordered = op.ordered_val()) { 696 p << " ordered(" << ordered << ")"; 697 } 698 699 if (op.inclusive()) { 700 p << " inclusive"; 701 } 702 703 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 704 } 705 706 //===----------------------------------------------------------------------===// 707 // WsLoopOp 708 //===----------------------------------------------------------------------===// 709 710 void WsLoopOp::build(OpBuilder &builder, OperationState &state, 711 ValueRange lowerBound, ValueRange upperBound, 712 ValueRange step, ArrayRef<NamedAttribute> attributes) { 713 build(builder, state, TypeRange(), lowerBound, upperBound, step, 714 /*private_vars=*/ValueRange(), 715 /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(), 716 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), 717 /*schedule_val=*/nullptr, /*schedule_chunk_var=*/nullptr, 718 /*collapse_val=*/nullptr, 719 /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr, 720 /*inclusive=*/nullptr, /*buildBody=*/false); 721 state.addAttributes(attributes); 722 } 723 724 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes, 725 ValueRange operands, ArrayRef<NamedAttribute> attributes) { 726 state.addOperands(operands); 727 state.addAttributes(attributes); 728 (void)state.addRegion(); 729 assert(resultTypes.size() == 0u && "mismatched number of return types"); 730 state.addTypes(resultTypes); 731 } 732 733 void WsLoopOp::build(OpBuilder &builder, OperationState &result, 734 TypeRange typeRange, ValueRange lowerBounds, 735 ValueRange upperBounds, ValueRange steps, 736 ValueRange privateVars, ValueRange firstprivateVars, 737 ValueRange lastprivateVars, ValueRange linearVars, 738 ValueRange linearStepVars, StringAttr scheduleVal, 739 Value scheduleChunkVar, IntegerAttr collapseVal, 740 UnitAttr nowait, IntegerAttr orderedVal, 741 StringAttr orderVal, UnitAttr inclusive, bool buildBody) { 742 result.addOperands(lowerBounds); 743 result.addOperands(upperBounds); 744 result.addOperands(steps); 745 result.addOperands(privateVars); 746 result.addOperands(firstprivateVars); 747 result.addOperands(linearVars); 748 result.addOperands(linearStepVars); 749 if (scheduleChunkVar) 750 result.addOperands(scheduleChunkVar); 751 752 if (scheduleVal) 753 result.addAttribute("schedule_val", scheduleVal); 754 if (collapseVal) 755 result.addAttribute("collapse_val", collapseVal); 756 if (nowait) 757 result.addAttribute("nowait", nowait); 758 if (orderedVal) 759 result.addAttribute("ordered_val", orderedVal); 760 if (orderVal) 761 result.addAttribute("order", orderVal); 762 if (inclusive) 763 result.addAttribute("inclusive", inclusive); 764 result.addAttribute( 765 WsLoopOp::getOperandSegmentSizeAttr(), 766 builder.getI32VectorAttr( 767 {static_cast<int32_t>(lowerBounds.size()), 768 static_cast<int32_t>(upperBounds.size()), 769 static_cast<int32_t>(steps.size()), 770 static_cast<int32_t>(privateVars.size()), 771 static_cast<int32_t>(firstprivateVars.size()), 772 static_cast<int32_t>(lastprivateVars.size()), 773 static_cast<int32_t>(linearVars.size()), 774 static_cast<int32_t>(linearStepVars.size()), 775 static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)})); 776 777 Region *bodyRegion = result.addRegion(); 778 if (buildBody) { 779 OpBuilder::InsertionGuard guard(builder); 780 unsigned numIVs = steps.size(); 781 SmallVector<Type, 8> argTypes(numIVs, steps.getType().front()); 782 builder.createBlock(bodyRegion, {}, argTypes); 783 } 784 } 785 786 #define GET_OP_CLASSES 787 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 788