1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the OpenMP dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/Attributes.h" 16 #include "mlir/IR/OpImplementation.h" 17 #include "mlir/IR/OperationSupport.h" 18 19 #include "llvm/ADT/SmallString.h" 20 #include "llvm/ADT/StringExtras.h" 21 #include "llvm/ADT/StringRef.h" 22 #include "llvm/ADT/StringSwitch.h" 23 #include <cstddef> 24 25 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" 26 27 using namespace mlir; 28 using namespace mlir::omp; 29 30 void OpenMPDialect::initialize() { 31 addOperations< 32 #define GET_OP_LIST 33 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 34 >(); 35 } 36 37 //===----------------------------------------------------------------------===// 38 // ParallelOp 39 //===----------------------------------------------------------------------===// 40 41 void ParallelOp::build(OpBuilder &builder, OperationState &state, 42 ArrayRef<NamedAttribute> attributes) { 43 ParallelOp::build( 44 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, 45 /*default_val=*/nullptr, /*private_vars=*/ValueRange(), 46 /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(), 47 /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(), 48 /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr); 49 state.addAttributes(attributes); 50 } 51 52 /// Parse a list of operands with types. 53 /// 54 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)` 55 /// ssa-id-and-type-list ::= ssa-id-and-type | 56 /// ssa-id-and-type `,` ssa-id-and-type-list 57 /// ssa-id-and-type ::= ssa-id `:` type 58 static ParseResult 59 parseOperandAndTypeList(OpAsmParser &parser, 60 SmallVectorImpl<OpAsmParser::OperandType> &operands, 61 SmallVectorImpl<Type> &types) { 62 if (parser.parseLParen()) 63 return failure(); 64 65 do { 66 OpAsmParser::OperandType operand; 67 Type type; 68 if (parser.parseOperand(operand) || parser.parseColonType(type)) 69 return failure(); 70 operands.push_back(operand); 71 types.push_back(type); 72 } while (succeeded(parser.parseOptionalComma())); 73 74 if (parser.parseRParen()) 75 return failure(); 76 77 return success(); 78 } 79 80 /// Parse an allocate clause with allocators and a list of operands with types. 81 /// 82 /// operand-and-type-list ::= `(` allocate-operand-list `)` 83 /// allocate-operand-list :: = allocate-operand | 84 /// allocator-operand `,` allocate-operand-list 85 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 86 /// ssa-id-and-type ::= ssa-id `:` type 87 static ParseResult parseAllocateAndAllocator( 88 OpAsmParser &parser, 89 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate, 90 SmallVectorImpl<Type> &typesAllocate, 91 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator, 92 SmallVectorImpl<Type> &typesAllocator) { 93 if (parser.parseLParen()) 94 return failure(); 95 96 do { 97 OpAsmParser::OperandType operand; 98 Type type; 99 100 if (parser.parseOperand(operand) || parser.parseColonType(type)) 101 return failure(); 102 operandsAllocator.push_back(operand); 103 typesAllocator.push_back(type); 104 if (parser.parseArrow()) 105 return failure(); 106 if (parser.parseOperand(operand) || parser.parseColonType(type)) 107 return failure(); 108 109 operandsAllocate.push_back(operand); 110 typesAllocate.push_back(type); 111 } while (succeeded(parser.parseOptionalComma())); 112 113 if (parser.parseRParen()) 114 return failure(); 115 116 return success(); 117 } 118 119 static LogicalResult verifyParallelOp(ParallelOp op) { 120 if (op.allocate_vars().size() != op.allocators_vars().size()) 121 return op.emitError( 122 "expected equal sizes for allocate and allocator variables"); 123 return success(); 124 } 125 126 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) { 127 p << "omp.parallel"; 128 129 if (auto ifCond = op.if_expr_var()) 130 p << " if(" << ifCond << " : " << ifCond.getType() << ")"; 131 132 if (auto threads = op.num_threads_var()) 133 p << " num_threads(" << threads << " : " << threads.getType() << ")"; 134 135 // Print private, firstprivate, shared and copyin parameters 136 auto printDataVars = [&p](StringRef name, OperandRange vars) { 137 if (vars.size()) { 138 p << " " << name << "("; 139 for (unsigned i = 0; i < vars.size(); ++i) { 140 std::string separator = i == vars.size() - 1 ? ")" : ", "; 141 p << vars[i] << " : " << vars[i].getType() << separator; 142 } 143 } 144 }; 145 146 // Print allocator and allocate parameters 147 auto printAllocateAndAllocator = [&p](OperandRange varsAllocate, 148 OperandRange varsAllocator) { 149 if (varsAllocate.empty()) 150 return; 151 152 p << " allocate("; 153 for (unsigned i = 0; i < varsAllocate.size(); ++i) { 154 std::string separator = i == varsAllocate.size() - 1 ? ")" : ", "; 155 p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> "; 156 p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator; 157 } 158 }; 159 160 printDataVars("private", op.private_vars()); 161 printDataVars("firstprivate", op.firstprivate_vars()); 162 printDataVars("shared", op.shared_vars()); 163 printDataVars("copyin", op.copyin_vars()); 164 printAllocateAndAllocator(op.allocate_vars(), op.allocators_vars()); 165 166 if (auto def = op.default_val()) 167 p << " default(" << def->drop_front(3) << ")"; 168 169 if (auto bind = op.proc_bind_val()) 170 p << " proc_bind(" << bind << ")"; 171 172 p.printRegion(op.getRegion()); 173 } 174 175 /// Emit an error if the same clause is present more than once on an operation. 176 static ParseResult allowedOnce(OpAsmParser &parser, StringRef clause, 177 StringRef operation) { 178 return parser.emitError(parser.getNameLoc()) 179 << " at most one " << clause << " clause can appear on the " 180 << operation << " operation"; 181 } 182 183 /// Parses a parallel operation. 184 /// 185 /// operation ::= `omp.parallel` clause-list 186 /// clause-list ::= clause | clause clause-list 187 /// clause ::= if | numThreads | private | firstprivate | shared | copyin | 188 /// default | procBind 189 /// if ::= `if` `(` ssa-id `)` 190 /// numThreads ::= `num_threads` `(` ssa-id-and-type `)` 191 /// private ::= `private` operand-and-type-list 192 /// firstprivate ::= `firstprivate` operand-and-type-list 193 /// shared ::= `shared` operand-and-type-list 194 /// copyin ::= `copyin` operand-and-type-list 195 /// allocate ::= `allocate` operand-and-type `->` operand-and-type-list 196 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) 197 /// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` 198 /// 199 /// Note that each clause can only appear once in the clase-list. 200 static ParseResult parseParallelOp(OpAsmParser &parser, 201 OperationState &result) { 202 std::pair<OpAsmParser::OperandType, Type> ifCond; 203 std::pair<OpAsmParser::OperandType, Type> numThreads; 204 SmallVector<OpAsmParser::OperandType, 4> privates; 205 SmallVector<Type, 4> privateTypes; 206 SmallVector<OpAsmParser::OperandType, 4> firstprivates; 207 SmallVector<Type, 4> firstprivateTypes; 208 SmallVector<OpAsmParser::OperandType, 4> shareds; 209 SmallVector<Type, 4> sharedTypes; 210 SmallVector<OpAsmParser::OperandType, 4> copyins; 211 SmallVector<Type, 4> copyinTypes; 212 SmallVector<OpAsmParser::OperandType, 4> allocates; 213 SmallVector<Type, 4> allocateTypes; 214 SmallVector<OpAsmParser::OperandType, 4> allocators; 215 SmallVector<Type, 4> allocatorTypes; 216 std::array<int, 8> segments{0, 0, 0, 0, 0, 0, 0, 0}; 217 StringRef keyword; 218 bool defaultVal = false; 219 bool procBind = false; 220 221 const int ifClausePos = 0; 222 const int numThreadsClausePos = 1; 223 const int privateClausePos = 2; 224 const int firstprivateClausePos = 3; 225 const int sharedClausePos = 4; 226 const int copyinClausePos = 5; 227 const int allocateClausePos = 6; 228 const int allocatorPos = 7; 229 const StringRef opName = result.name.getStringRef(); 230 231 while (succeeded(parser.parseOptionalKeyword(&keyword))) { 232 if (keyword == "if") { 233 // Fail if there was already another if condition. 234 if (segments[ifClausePos]) 235 return allowedOnce(parser, "if", opName); 236 if (parser.parseLParen() || parser.parseOperand(ifCond.first) || 237 parser.parseColonType(ifCond.second) || parser.parseRParen()) 238 return failure(); 239 segments[ifClausePos] = 1; 240 } else if (keyword == "num_threads") { 241 // Fail if there was already another num_threads clause. 242 if (segments[numThreadsClausePos]) 243 return allowedOnce(parser, "num_threads", opName); 244 if (parser.parseLParen() || parser.parseOperand(numThreads.first) || 245 parser.parseColonType(numThreads.second) || parser.parseRParen()) 246 return failure(); 247 segments[numThreadsClausePos] = 1; 248 } else if (keyword == "private") { 249 // Fail if there was already another private clause. 250 if (segments[privateClausePos]) 251 return allowedOnce(parser, "private", opName); 252 if (parseOperandAndTypeList(parser, privates, privateTypes)) 253 return failure(); 254 segments[privateClausePos] = privates.size(); 255 } else if (keyword == "firstprivate") { 256 // Fail if there was already another firstprivate clause. 257 if (segments[firstprivateClausePos]) 258 return allowedOnce(parser, "firstprivate", opName); 259 if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) 260 return failure(); 261 segments[firstprivateClausePos] = firstprivates.size(); 262 } else if (keyword == "shared") { 263 // Fail if there was already another shared clause. 264 if (segments[sharedClausePos]) 265 return allowedOnce(parser, "shared", opName); 266 if (parseOperandAndTypeList(parser, shareds, sharedTypes)) 267 return failure(); 268 segments[sharedClausePos] = shareds.size(); 269 } else if (keyword == "copyin") { 270 // Fail if there was already another copyin clause. 271 if (segments[copyinClausePos]) 272 return allowedOnce(parser, "copyin", opName); 273 if (parseOperandAndTypeList(parser, copyins, copyinTypes)) 274 return failure(); 275 segments[copyinClausePos] = copyins.size(); 276 } else if (keyword == "allocate") { 277 // Fail if there was already another allocate clause. 278 if (segments[allocateClausePos]) 279 return allowedOnce(parser, "allocate", opName); 280 if (parseAllocateAndAllocator(parser, allocates, allocateTypes, 281 allocators, allocatorTypes)) 282 return failure(); 283 segments[allocateClausePos] = allocates.size(); 284 segments[allocatorPos] = allocators.size(); 285 } else if (keyword == "default") { 286 // Fail if there was already another default clause. 287 if (defaultVal) 288 return allowedOnce(parser, "default", opName); 289 defaultVal = true; 290 StringRef defval; 291 if (parser.parseLParen() || parser.parseKeyword(&defval) || 292 parser.parseRParen()) 293 return failure(); 294 // The def prefix is required for the attribute as "private" is a keyword 295 // in C++. 296 auto attr = parser.getBuilder().getStringAttr("def" + defval); 297 result.addAttribute("default_val", attr); 298 } else if (keyword == "proc_bind") { 299 // Fail if there was already another proc_bind clause. 300 if (procBind) 301 return allowedOnce(parser, "proc_bind", opName); 302 procBind = true; 303 StringRef bind; 304 if (parser.parseLParen() || parser.parseKeyword(&bind) || 305 parser.parseRParen()) 306 return failure(); 307 auto attr = parser.getBuilder().getStringAttr(bind); 308 result.addAttribute("proc_bind_val", attr); 309 } else { 310 return parser.emitError(parser.getNameLoc()) 311 << keyword << " is not a valid clause for the " << opName 312 << " operation"; 313 } 314 } 315 316 // Add if parameter. 317 if (segments[ifClausePos] && 318 parser.resolveOperand(ifCond.first, ifCond.second, result.operands)) 319 return failure(); 320 321 // Add num_threads parameter. 322 if (segments[numThreadsClausePos] && 323 parser.resolveOperand(numThreads.first, numThreads.second, 324 result.operands)) 325 return failure(); 326 327 // Add private parameters. 328 if (segments[privateClausePos] && 329 parser.resolveOperands(privates, privateTypes, privates[0].location, 330 result.operands)) 331 return failure(); 332 333 // Add firstprivate parameters. 334 if (segments[firstprivateClausePos] && 335 parser.resolveOperands(firstprivates, firstprivateTypes, 336 firstprivates[0].location, result.operands)) 337 return failure(); 338 339 // Add shared parameters. 340 if (segments[sharedClausePos] && 341 parser.resolveOperands(shareds, sharedTypes, shareds[0].location, 342 result.operands)) 343 return failure(); 344 345 // Add copyin parameters. 346 if (segments[copyinClausePos] && 347 parser.resolveOperands(copyins, copyinTypes, copyins[0].location, 348 result.operands)) 349 return failure(); 350 351 // Add allocate parameters. 352 if (segments[allocateClausePos] && 353 parser.resolveOperands(allocates, allocateTypes, allocates[0].location, 354 result.operands)) 355 return failure(); 356 357 // Add allocator parameters. 358 if (segments[allocatorPos] && 359 parser.resolveOperands(allocators, allocatorTypes, allocators[0].location, 360 result.operands)) 361 return failure(); 362 363 result.addAttribute("operand_segment_sizes", 364 parser.getBuilder().getI32VectorAttr(segments)); 365 366 Region *body = result.addRegion(); 367 SmallVector<OpAsmParser::OperandType, 4> regionArgs; 368 SmallVector<Type, 4> regionArgTypes; 369 if (parser.parseRegion(*body, regionArgs, regionArgTypes)) 370 return failure(); 371 return success(); 372 } 373 374 /// linear ::= `linear` `(` linear-list `)` 375 /// linear-list := linear-val | linear-val linear-list 376 /// linear-val := ssa-id-and-type `=` ssa-id-and-type 377 static ParseResult 378 parseLinearClause(OpAsmParser &parser, 379 SmallVectorImpl<OpAsmParser::OperandType> &vars, 380 SmallVectorImpl<Type> &types, 381 SmallVectorImpl<OpAsmParser::OperandType> &stepVars) { 382 if (parser.parseLParen()) 383 return failure(); 384 385 do { 386 OpAsmParser::OperandType var; 387 Type type; 388 OpAsmParser::OperandType stepVar; 389 if (parser.parseOperand(var) || parser.parseEqual() || 390 parser.parseOperand(stepVar) || parser.parseColonType(type)) 391 return failure(); 392 393 vars.push_back(var); 394 types.push_back(type); 395 stepVars.push_back(stepVar); 396 } while (succeeded(parser.parseOptionalComma())); 397 398 if (parser.parseRParen()) 399 return failure(); 400 401 return success(); 402 } 403 404 /// schedule ::= `schedule` `(` sched-list `)` 405 /// sched-list ::= sched-val | sched-val sched-list 406 /// sched-val ::= sched-with-chunk | sched-wo-chunk 407 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 408 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided` 409 /// sched-wo-chunk ::= `auto` | `runtime` 410 static ParseResult 411 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, 412 Optional<OpAsmParser::OperandType> &chunkSize) { 413 if (parser.parseLParen()) 414 return failure(); 415 416 StringRef keyword; 417 if (parser.parseKeyword(&keyword)) 418 return failure(); 419 420 schedule = keyword; 421 if (keyword == "static" || keyword == "dynamic" || keyword == "guided") { 422 if (succeeded(parser.parseOptionalEqual())) { 423 chunkSize = OpAsmParser::OperandType{}; 424 if (parser.parseOperand(*chunkSize)) 425 return failure(); 426 } else { 427 chunkSize = llvm::NoneType::None; 428 } 429 } else if (keyword == "auto" || keyword == "runtime") { 430 chunkSize = llvm::NoneType::None; 431 } else { 432 return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; 433 } 434 435 if (parser.parseRParen()) 436 return failure(); 437 438 return success(); 439 } 440 441 /// Parses an OpenMP Workshare Loop operation 442 /// 443 /// operation ::= `omp.wsloop` loop-control clause-list 444 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds 445 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps 446 /// steps := `step` `(`ssa-id-list`)` 447 /// clause-list ::= clause | empty | clause-list 448 /// clause ::= private | firstprivate | lastprivate | linear | schedule | 449 // collapse | nowait | ordered | order | inclusive 450 /// private ::= `private` `(` ssa-id-and-type-list `)` 451 /// firstprivate ::= `firstprivate` `(` ssa-id-and-type-list `)` 452 /// lastprivate ::= `lastprivate` `(` ssa-id-and-type-list `)` 453 /// linear ::= `linear` `(` linear-list `)` 454 /// schedule ::= `schedule` `(` sched-list `)` 455 /// collapse ::= `collapse` `(` ssa-id-and-type `)` 456 /// nowait ::= `nowait` 457 /// ordered ::= `ordered` `(` ssa-id-and-type `)` 458 /// order ::= `order` `(` `concurrent` `)` 459 /// inclusive ::= `inclusive` 460 /// 461 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) { 462 Type loopVarType; 463 int numIVs; 464 465 // Parse an opening `(` followed by induction variables followed by `)` 466 SmallVector<OpAsmParser::OperandType> ivs; 467 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, 468 OpAsmParser::Delimiter::Paren)) 469 return failure(); 470 471 numIVs = static_cast<int>(ivs.size()); 472 473 if (parser.parseColonType(loopVarType)) 474 return failure(); 475 476 // Parse loop bounds. 477 SmallVector<OpAsmParser::OperandType> lower; 478 if (parser.parseEqual() || 479 parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) || 480 parser.resolveOperands(lower, loopVarType, result.operands)) 481 return failure(); 482 483 SmallVector<OpAsmParser::OperandType> upper; 484 if (parser.parseKeyword("to") || 485 parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) || 486 parser.resolveOperands(upper, loopVarType, result.operands)) 487 return failure(); 488 489 // Parse step values. 490 SmallVector<OpAsmParser::OperandType> steps; 491 if (parser.parseKeyword("step") || 492 parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) || 493 parser.resolveOperands(steps, loopVarType, result.operands)) 494 return failure(); 495 496 SmallVector<OpAsmParser::OperandType> privates; 497 SmallVector<Type> privateTypes; 498 SmallVector<OpAsmParser::OperandType> firstprivates; 499 SmallVector<Type> firstprivateTypes; 500 SmallVector<OpAsmParser::OperandType> lastprivates; 501 SmallVector<Type> lastprivateTypes; 502 SmallVector<OpAsmParser::OperandType> linears; 503 SmallVector<Type> linearTypes; 504 SmallVector<OpAsmParser::OperandType> linearSteps; 505 SmallString<8> schedule; 506 Optional<OpAsmParser::OperandType> scheduleChunkSize; 507 std::array<int, 9> segments{numIVs, numIVs, numIVs, 0, 0, 0, 0, 0, 0}; 508 509 const StringRef opName = result.name.getStringRef(); 510 StringRef keyword; 511 512 enum SegmentPos { 513 lbPos = 0, 514 ubPos, 515 stepPos, 516 privateClausePos, 517 firstprivateClausePos, 518 lastprivateClausePos, 519 linearClausePos, 520 linearStepPos, 521 scheduleClausePos, 522 }; 523 524 while (succeeded(parser.parseOptionalKeyword(&keyword))) { 525 if (keyword == "private") { 526 if (segments[privateClausePos]) 527 return allowedOnce(parser, "private", opName); 528 if (parseOperandAndTypeList(parser, privates, privateTypes)) 529 return failure(); 530 segments[privateClausePos] = privates.size(); 531 } else if (keyword == "firstprivate") { 532 // fail if there was already another firstprivate clause 533 if (segments[firstprivateClausePos]) 534 return allowedOnce(parser, "firstprivate", opName); 535 if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) 536 return failure(); 537 segments[firstprivateClausePos] = firstprivates.size(); 538 } else if (keyword == "lastprivate") { 539 // fail if there was already another shared clause 540 if (segments[lastprivateClausePos]) 541 return allowedOnce(parser, "lastprivate", opName); 542 if (parseOperandAndTypeList(parser, lastprivates, lastprivateTypes)) 543 return failure(); 544 segments[lastprivateClausePos] = lastprivates.size(); 545 } else if (keyword == "linear") { 546 // fail if there was already another linear clause 547 if (segments[linearClausePos]) 548 return allowedOnce(parser, "linear", opName); 549 if (parseLinearClause(parser, linears, linearTypes, linearSteps)) 550 return failure(); 551 segments[linearClausePos] = linears.size(); 552 segments[linearStepPos] = linearSteps.size(); 553 } else if (keyword == "schedule") { 554 if (!schedule.empty()) 555 return allowedOnce(parser, "schedule", opName); 556 if (parseScheduleClause(parser, schedule, scheduleChunkSize)) 557 return failure(); 558 if (scheduleChunkSize) { 559 segments[scheduleClausePos] = 1; 560 } 561 } else if (keyword == "collapse") { 562 auto type = parser.getBuilder().getI64Type(); 563 mlir::IntegerAttr attr; 564 if (parser.parseLParen() || parser.parseAttribute(attr, type) || 565 parser.parseRParen()) 566 return failure(); 567 result.addAttribute("collapse_val", attr); 568 } else if (keyword == "nowait") { 569 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 570 result.addAttribute("nowait", attr); 571 } else if (keyword == "ordered") { 572 mlir::IntegerAttr attr; 573 if (succeeded(parser.parseOptionalLParen())) { 574 auto type = parser.getBuilder().getI64Type(); 575 if (parser.parseAttribute(attr, type)) 576 return failure(); 577 if (parser.parseRParen()) 578 return failure(); 579 } else { 580 // Use 0 to represent no ordered parameter was specified 581 attr = parser.getBuilder().getI64IntegerAttr(0); 582 } 583 result.addAttribute("ordered_val", attr); 584 } else if (keyword == "order") { 585 StringRef order; 586 if (parser.parseLParen() || parser.parseKeyword(&order) || 587 parser.parseRParen()) 588 return failure(); 589 auto attr = parser.getBuilder().getStringAttr(order); 590 result.addAttribute("order", attr); 591 } else if (keyword == "inclusive") { 592 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 593 result.addAttribute("inclusive", attr); 594 } 595 } 596 597 if (segments[privateClausePos]) { 598 parser.resolveOperands(privates, privateTypes, privates[0].location, 599 result.operands); 600 } 601 602 if (segments[firstprivateClausePos]) { 603 parser.resolveOperands(firstprivates, firstprivateTypes, 604 firstprivates[0].location, result.operands); 605 } 606 607 if (segments[lastprivateClausePos]) { 608 parser.resolveOperands(lastprivates, lastprivateTypes, 609 lastprivates[0].location, result.operands); 610 } 611 612 if (segments[linearClausePos]) { 613 parser.resolveOperands(linears, linearTypes, linears[0].location, 614 result.operands); 615 auto linearStepType = parser.getBuilder().getI32Type(); 616 SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType); 617 parser.resolveOperands(linearSteps, linearStepTypes, 618 linearSteps[0].location, result.operands); 619 } 620 621 if (!schedule.empty()) { 622 schedule[0] = llvm::toUpper(schedule[0]); 623 auto attr = parser.getBuilder().getStringAttr(schedule); 624 result.addAttribute("schedule_val", attr); 625 if (scheduleChunkSize) { 626 auto chunkSizeType = parser.getBuilder().getI32Type(); 627 parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands); 628 } 629 } 630 631 result.addAttribute("operand_segment_sizes", 632 parser.getBuilder().getI32VectorAttr(segments)); 633 634 // Now parse the body. 635 Region *body = result.addRegion(); 636 SmallVector<Type> ivTypes(numIVs, loopVarType); 637 if (parser.parseRegion(*body, ivs, ivTypes)) 638 return failure(); 639 return success(); 640 } 641 642 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) { 643 auto args = op.getRegion().front().getArguments(); 644 p << op.getOperationName() << " (" << args << ") : " << args[0].getType() 645 << " = (" << op.lowerBound() << ") to (" << op.upperBound() << ") step (" 646 << op.step() << ")"; 647 648 // Print private, firstprivate, shared and copyin parameters 649 auto printDataVars = [&p](StringRef name, OperandRange vars) { 650 if (vars.empty()) 651 return; 652 653 p << " " << name << "("; 654 llvm::interleaveComma( 655 vars, p, [&](const Value &v) { p << v << " : " << v.getType(); }); 656 p << ")"; 657 }; 658 printDataVars("private", op.private_vars()); 659 printDataVars("firstprivate", op.firstprivate_vars()); 660 printDataVars("lastprivate", op.lastprivate_vars()); 661 662 auto linearVars = op.linear_vars(); 663 auto linearVarsSize = linearVars.size(); 664 if (linearVarsSize) { 665 p << " " 666 << "linear" 667 << "("; 668 for (unsigned i = 0; i < linearVarsSize; ++i) { 669 std::string separator = i == linearVarsSize - 1 ? ")" : ", "; 670 p << linearVars[i]; 671 if (op.linear_step_vars().size() > i) 672 p << " = " << op.linear_step_vars()[i]; 673 p << " : " << linearVars[i].getType() << separator; 674 } 675 } 676 677 if (auto sched = op.schedule_val()) { 678 auto schedLower = sched->lower(); 679 p << " schedule(" << schedLower; 680 if (auto chunk = op.schedule_chunk_var()) { 681 p << " = " << chunk; 682 } 683 p << ")"; 684 } 685 686 if (auto collapse = op.collapse_val()) 687 p << " collapse(" << collapse << ")"; 688 689 if (op.nowait()) 690 p << " nowait"; 691 692 if (auto ordered = op.ordered_val()) { 693 p << " ordered(" << ordered << ")"; 694 } 695 696 if (op.inclusive()) { 697 p << " inclusive"; 698 } 699 700 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 701 } 702 703 //===----------------------------------------------------------------------===// 704 // WsLoopOp 705 //===----------------------------------------------------------------------===// 706 707 void WsLoopOp::build(OpBuilder &builder, OperationState &state, 708 ValueRange lowerBound, ValueRange upperBound, 709 ValueRange step, ArrayRef<NamedAttribute> attributes) { 710 build(builder, state, TypeRange(), lowerBound, upperBound, step, 711 /*private_vars=*/ValueRange(), 712 /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(), 713 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), 714 /*schedule_val=*/nullptr, /*schedule_chunk_var=*/nullptr, 715 /*collapse_val=*/nullptr, 716 /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr, 717 /*inclusive=*/nullptr, /*buildBody=*/false); 718 state.addAttributes(attributes); 719 } 720 721 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes, 722 ValueRange operands, ArrayRef<NamedAttribute> attributes) { 723 state.addOperands(operands); 724 state.addAttributes(attributes); 725 (void)state.addRegion(); 726 assert(resultTypes.size() == 0u && "mismatched number of return types"); 727 state.addTypes(resultTypes); 728 } 729 730 void WsLoopOp::build(OpBuilder &builder, OperationState &result, 731 TypeRange typeRange, ValueRange lowerBounds, 732 ValueRange upperBounds, ValueRange steps, 733 ValueRange privateVars, ValueRange firstprivateVars, 734 ValueRange lastprivateVars, ValueRange linearVars, 735 ValueRange linearStepVars, StringAttr scheduleVal, 736 Value scheduleChunkVar, IntegerAttr collapseVal, 737 UnitAttr nowait, IntegerAttr orderedVal, 738 StringAttr orderVal, UnitAttr inclusive, bool buildBody) { 739 result.addOperands(lowerBounds); 740 result.addOperands(upperBounds); 741 result.addOperands(steps); 742 result.addOperands(privateVars); 743 result.addOperands(firstprivateVars); 744 result.addOperands(linearVars); 745 result.addOperands(linearStepVars); 746 if (scheduleChunkVar) 747 result.addOperands(scheduleChunkVar); 748 749 if (scheduleVal) 750 result.addAttribute("schedule_val", scheduleVal); 751 if (collapseVal) 752 result.addAttribute("collapse_val", collapseVal); 753 if (nowait) 754 result.addAttribute("nowait", nowait); 755 if (orderedVal) 756 result.addAttribute("ordered_val", orderedVal); 757 if (orderVal) 758 result.addAttribute("order", orderVal); 759 if (inclusive) 760 result.addAttribute("inclusive", inclusive); 761 result.addAttribute( 762 WsLoopOp::getOperandSegmentSizeAttr(), 763 builder.getI32VectorAttr( 764 {static_cast<int32_t>(lowerBounds.size()), 765 static_cast<int32_t>(upperBounds.size()), 766 static_cast<int32_t>(steps.size()), 767 static_cast<int32_t>(privateVars.size()), 768 static_cast<int32_t>(firstprivateVars.size()), 769 static_cast<int32_t>(lastprivateVars.size()), 770 static_cast<int32_t>(linearVars.size()), 771 static_cast<int32_t>(linearStepVars.size()), 772 static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)})); 773 774 Region *bodyRegion = result.addRegion(); 775 if (buildBody) { 776 OpBuilder::InsertionGuard guard(builder); 777 unsigned numIVs = steps.size(); 778 SmallVector<Type, 8> argTypes(numIVs, steps.getType().front()); 779 builder.createBlock(bodyRegion, {}, argTypes); 780 } 781 } 782 783 #define GET_OP_CLASSES 784 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 785