1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the OpenMP dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/IR/Attributes.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/IR/OperationSupport.h" 19 20 #include "llvm/ADT/SmallString.h" 21 #include "llvm/ADT/StringExtras.h" 22 #include "llvm/ADT/StringRef.h" 23 #include "llvm/ADT/StringSwitch.h" 24 #include <cstddef> 25 26 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc" 27 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" 28 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc" 29 30 using namespace mlir; 31 using namespace mlir::omp; 32 33 namespace { 34 /// Model for pointer-like types that already provide a `getElementType` method. 35 template <typename T> 36 struct PointerLikeModel 37 : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> { 38 Type getElementType(Type pointer) const { 39 return pointer.cast<T>().getElementType(); 40 } 41 }; 42 } // end namespace 43 44 void OpenMPDialect::initialize() { 45 addOperations< 46 #define GET_OP_LIST 47 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 48 >(); 49 50 LLVM::LLVMPointerType::attachInterface< 51 PointerLikeModel<LLVM::LLVMPointerType>>(*getContext()); 52 MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext()); 53 } 54 55 //===----------------------------------------------------------------------===// 56 // ParallelOp 57 //===----------------------------------------------------------------------===// 58 59 void ParallelOp::build(OpBuilder &builder, OperationState &state, 60 ArrayRef<NamedAttribute> attributes) { 61 ParallelOp::build( 62 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, 63 /*default_val=*/nullptr, /*private_vars=*/ValueRange(), 64 /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(), 65 /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(), 66 /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr); 67 state.addAttributes(attributes); 68 } 69 70 /// Parse a list of operands with types. 71 /// 72 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)` 73 /// ssa-id-and-type-list ::= ssa-id-and-type | 74 /// ssa-id-and-type `,` ssa-id-and-type-list 75 /// ssa-id-and-type ::= ssa-id `:` type 76 static ParseResult 77 parseOperandAndTypeList(OpAsmParser &parser, 78 SmallVectorImpl<OpAsmParser::OperandType> &operands, 79 SmallVectorImpl<Type> &types) { 80 if (parser.parseLParen()) 81 return failure(); 82 83 do { 84 OpAsmParser::OperandType operand; 85 Type type; 86 if (parser.parseOperand(operand) || parser.parseColonType(type)) 87 return failure(); 88 operands.push_back(operand); 89 types.push_back(type); 90 } while (succeeded(parser.parseOptionalComma())); 91 92 if (parser.parseRParen()) 93 return failure(); 94 95 return success(); 96 } 97 98 /// Parse an allocate clause with allocators and a list of operands with types. 99 /// 100 /// operand-and-type-list ::= `(` allocate-operand-list `)` 101 /// allocate-operand-list :: = allocate-operand | 102 /// allocator-operand `,` allocate-operand-list 103 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 104 /// ssa-id-and-type ::= ssa-id `:` type 105 static ParseResult parseAllocateAndAllocator( 106 OpAsmParser &parser, 107 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate, 108 SmallVectorImpl<Type> &typesAllocate, 109 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator, 110 SmallVectorImpl<Type> &typesAllocator) { 111 if (parser.parseLParen()) 112 return failure(); 113 114 do { 115 OpAsmParser::OperandType operand; 116 Type type; 117 118 if (parser.parseOperand(operand) || parser.parseColonType(type)) 119 return failure(); 120 operandsAllocator.push_back(operand); 121 typesAllocator.push_back(type); 122 if (parser.parseArrow()) 123 return failure(); 124 if (parser.parseOperand(operand) || parser.parseColonType(type)) 125 return failure(); 126 127 operandsAllocate.push_back(operand); 128 typesAllocate.push_back(type); 129 } while (succeeded(parser.parseOptionalComma())); 130 131 if (parser.parseRParen()) 132 return failure(); 133 134 return success(); 135 } 136 137 static LogicalResult verifyParallelOp(ParallelOp op) { 138 if (op.allocate_vars().size() != op.allocators_vars().size()) 139 return op.emitError( 140 "expected equal sizes for allocate and allocator variables"); 141 return success(); 142 } 143 144 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) { 145 if (auto ifCond = op.if_expr_var()) 146 p << " if(" << ifCond << " : " << ifCond.getType() << ")"; 147 148 if (auto threads = op.num_threads_var()) 149 p << " num_threads(" << threads << " : " << threads.getType() << ")"; 150 151 // Print private, firstprivate, shared and copyin parameters 152 auto printDataVars = [&p](StringRef name, OperandRange vars) { 153 if (vars.size()) { 154 p << " " << name << "("; 155 for (unsigned i = 0; i < vars.size(); ++i) { 156 std::string separator = i == vars.size() - 1 ? ")" : ", "; 157 p << vars[i] << " : " << vars[i].getType() << separator; 158 } 159 } 160 }; 161 162 // Print allocator and allocate parameters 163 auto printAllocateAndAllocator = [&p](OperandRange varsAllocate, 164 OperandRange varsAllocator) { 165 if (varsAllocate.empty()) 166 return; 167 168 p << " allocate("; 169 for (unsigned i = 0; i < varsAllocate.size(); ++i) { 170 std::string separator = i == varsAllocate.size() - 1 ? ")" : ", "; 171 p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> "; 172 p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator; 173 } 174 }; 175 176 printDataVars("private", op.private_vars()); 177 printDataVars("firstprivate", op.firstprivate_vars()); 178 printDataVars("shared", op.shared_vars()); 179 printDataVars("copyin", op.copyin_vars()); 180 printAllocateAndAllocator(op.allocate_vars(), op.allocators_vars()); 181 182 if (auto def = op.default_val()) 183 p << " default(" << def->drop_front(3) << ")"; 184 185 if (auto bind = op.proc_bind_val()) 186 p << " proc_bind(" << bind << ")"; 187 188 p.printRegion(op.getRegion()); 189 } 190 191 /// Emit an error if the same clause is present more than once on an operation. 192 static ParseResult allowedOnce(OpAsmParser &parser, StringRef clause, 193 StringRef operation) { 194 return parser.emitError(parser.getNameLoc()) 195 << " at most one " << clause << " clause can appear on the " 196 << operation << " operation"; 197 } 198 199 /// Parses a parallel operation. 200 /// 201 /// operation ::= `omp.parallel` clause-list 202 /// clause-list ::= clause | clause clause-list 203 /// clause ::= if | numThreads | private | firstprivate | shared | copyin | 204 /// default | procBind 205 /// if ::= `if` `(` ssa-id `)` 206 /// numThreads ::= `num_threads` `(` ssa-id-and-type `)` 207 /// private ::= `private` operand-and-type-list 208 /// firstprivate ::= `firstprivate` operand-and-type-list 209 /// shared ::= `shared` operand-and-type-list 210 /// copyin ::= `copyin` operand-and-type-list 211 /// allocate ::= `allocate` operand-and-type `->` operand-and-type-list 212 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) 213 /// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` 214 /// 215 /// Note that each clause can only appear once in the clase-list. 216 static ParseResult parseParallelOp(OpAsmParser &parser, 217 OperationState &result) { 218 std::pair<OpAsmParser::OperandType, Type> ifCond; 219 std::pair<OpAsmParser::OperandType, Type> numThreads; 220 SmallVector<OpAsmParser::OperandType, 4> privates; 221 SmallVector<Type, 4> privateTypes; 222 SmallVector<OpAsmParser::OperandType, 4> firstprivates; 223 SmallVector<Type, 4> firstprivateTypes; 224 SmallVector<OpAsmParser::OperandType, 4> shareds; 225 SmallVector<Type, 4> sharedTypes; 226 SmallVector<OpAsmParser::OperandType, 4> copyins; 227 SmallVector<Type, 4> copyinTypes; 228 SmallVector<OpAsmParser::OperandType, 4> allocates; 229 SmallVector<Type, 4> allocateTypes; 230 SmallVector<OpAsmParser::OperandType, 4> allocators; 231 SmallVector<Type, 4> allocatorTypes; 232 std::array<int, 8> segments{0, 0, 0, 0, 0, 0, 0, 0}; 233 StringRef keyword; 234 bool defaultVal = false; 235 bool procBind = false; 236 237 const int ifClausePos = 0; 238 const int numThreadsClausePos = 1; 239 const int privateClausePos = 2; 240 const int firstprivateClausePos = 3; 241 const int sharedClausePos = 4; 242 const int copyinClausePos = 5; 243 const int allocateClausePos = 6; 244 const int allocatorPos = 7; 245 const StringRef opName = result.name.getStringRef(); 246 247 while (succeeded(parser.parseOptionalKeyword(&keyword))) { 248 if (keyword == "if") { 249 // Fail if there was already another if condition. 250 if (segments[ifClausePos]) 251 return allowedOnce(parser, "if", opName); 252 if (parser.parseLParen() || parser.parseOperand(ifCond.first) || 253 parser.parseColonType(ifCond.second) || parser.parseRParen()) 254 return failure(); 255 segments[ifClausePos] = 1; 256 } else if (keyword == "num_threads") { 257 // Fail if there was already another num_threads clause. 258 if (segments[numThreadsClausePos]) 259 return allowedOnce(parser, "num_threads", opName); 260 if (parser.parseLParen() || parser.parseOperand(numThreads.first) || 261 parser.parseColonType(numThreads.second) || parser.parseRParen()) 262 return failure(); 263 segments[numThreadsClausePos] = 1; 264 } else if (keyword == "private") { 265 // Fail if there was already another private clause. 266 if (segments[privateClausePos]) 267 return allowedOnce(parser, "private", opName); 268 if (parseOperandAndTypeList(parser, privates, privateTypes)) 269 return failure(); 270 segments[privateClausePos] = privates.size(); 271 } else if (keyword == "firstprivate") { 272 // Fail if there was already another firstprivate clause. 273 if (segments[firstprivateClausePos]) 274 return allowedOnce(parser, "firstprivate", opName); 275 if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) 276 return failure(); 277 segments[firstprivateClausePos] = firstprivates.size(); 278 } else if (keyword == "shared") { 279 // Fail if there was already another shared clause. 280 if (segments[sharedClausePos]) 281 return allowedOnce(parser, "shared", opName); 282 if (parseOperandAndTypeList(parser, shareds, sharedTypes)) 283 return failure(); 284 segments[sharedClausePos] = shareds.size(); 285 } else if (keyword == "copyin") { 286 // Fail if there was already another copyin clause. 287 if (segments[copyinClausePos]) 288 return allowedOnce(parser, "copyin", opName); 289 if (parseOperandAndTypeList(parser, copyins, copyinTypes)) 290 return failure(); 291 segments[copyinClausePos] = copyins.size(); 292 } else if (keyword == "allocate") { 293 // Fail if there was already another allocate clause. 294 if (segments[allocateClausePos]) 295 return allowedOnce(parser, "allocate", opName); 296 if (parseAllocateAndAllocator(parser, allocates, allocateTypes, 297 allocators, allocatorTypes)) 298 return failure(); 299 segments[allocateClausePos] = allocates.size(); 300 segments[allocatorPos] = allocators.size(); 301 } else if (keyword == "default") { 302 // Fail if there was already another default clause. 303 if (defaultVal) 304 return allowedOnce(parser, "default", opName); 305 defaultVal = true; 306 StringRef defval; 307 if (parser.parseLParen() || parser.parseKeyword(&defval) || 308 parser.parseRParen()) 309 return failure(); 310 // The def prefix is required for the attribute as "private" is a keyword 311 // in C++. 312 auto attr = parser.getBuilder().getStringAttr("def" + defval); 313 result.addAttribute("default_val", attr); 314 } else if (keyword == "proc_bind") { 315 // Fail if there was already another proc_bind clause. 316 if (procBind) 317 return allowedOnce(parser, "proc_bind", opName); 318 procBind = true; 319 StringRef bind; 320 if (parser.parseLParen() || parser.parseKeyword(&bind) || 321 parser.parseRParen()) 322 return failure(); 323 auto attr = parser.getBuilder().getStringAttr(bind); 324 result.addAttribute("proc_bind_val", attr); 325 } else { 326 return parser.emitError(parser.getNameLoc()) 327 << keyword << " is not a valid clause for the " << opName 328 << " operation"; 329 } 330 } 331 332 // Add if parameter. 333 if (segments[ifClausePos] && 334 parser.resolveOperand(ifCond.first, ifCond.second, result.operands)) 335 return failure(); 336 337 // Add num_threads parameter. 338 if (segments[numThreadsClausePos] && 339 parser.resolveOperand(numThreads.first, numThreads.second, 340 result.operands)) 341 return failure(); 342 343 // Add private parameters. 344 if (segments[privateClausePos] && 345 parser.resolveOperands(privates, privateTypes, privates[0].location, 346 result.operands)) 347 return failure(); 348 349 // Add firstprivate parameters. 350 if (segments[firstprivateClausePos] && 351 parser.resolveOperands(firstprivates, firstprivateTypes, 352 firstprivates[0].location, result.operands)) 353 return failure(); 354 355 // Add shared parameters. 356 if (segments[sharedClausePos] && 357 parser.resolveOperands(shareds, sharedTypes, shareds[0].location, 358 result.operands)) 359 return failure(); 360 361 // Add copyin parameters. 362 if (segments[copyinClausePos] && 363 parser.resolveOperands(copyins, copyinTypes, copyins[0].location, 364 result.operands)) 365 return failure(); 366 367 // Add allocate parameters. 368 if (segments[allocateClausePos] && 369 parser.resolveOperands(allocates, allocateTypes, allocates[0].location, 370 result.operands)) 371 return failure(); 372 373 // Add allocator parameters. 374 if (segments[allocatorPos] && 375 parser.resolveOperands(allocators, allocatorTypes, allocators[0].location, 376 result.operands)) 377 return failure(); 378 379 result.addAttribute("operand_segment_sizes", 380 parser.getBuilder().getI32VectorAttr(segments)); 381 382 Region *body = result.addRegion(); 383 SmallVector<OpAsmParser::OperandType, 4> regionArgs; 384 SmallVector<Type, 4> regionArgTypes; 385 if (parser.parseRegion(*body, regionArgs, regionArgTypes)) 386 return failure(); 387 return success(); 388 } 389 390 /// linear ::= `linear` `(` linear-list `)` 391 /// linear-list := linear-val | linear-val linear-list 392 /// linear-val := ssa-id-and-type `=` ssa-id-and-type 393 static ParseResult 394 parseLinearClause(OpAsmParser &parser, 395 SmallVectorImpl<OpAsmParser::OperandType> &vars, 396 SmallVectorImpl<Type> &types, 397 SmallVectorImpl<OpAsmParser::OperandType> &stepVars) { 398 if (parser.parseLParen()) 399 return failure(); 400 401 do { 402 OpAsmParser::OperandType var; 403 Type type; 404 OpAsmParser::OperandType stepVar; 405 if (parser.parseOperand(var) || parser.parseEqual() || 406 parser.parseOperand(stepVar) || parser.parseColonType(type)) 407 return failure(); 408 409 vars.push_back(var); 410 types.push_back(type); 411 stepVars.push_back(stepVar); 412 } while (succeeded(parser.parseOptionalComma())); 413 414 if (parser.parseRParen()) 415 return failure(); 416 417 return success(); 418 } 419 420 /// schedule ::= `schedule` `(` sched-list `)` 421 /// sched-list ::= sched-val | sched-val sched-list 422 /// sched-val ::= sched-with-chunk | sched-wo-chunk 423 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 424 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided` 425 /// sched-wo-chunk ::= `auto` | `runtime` 426 static ParseResult 427 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, 428 Optional<OpAsmParser::OperandType> &chunkSize) { 429 if (parser.parseLParen()) 430 return failure(); 431 432 StringRef keyword; 433 if (parser.parseKeyword(&keyword)) 434 return failure(); 435 436 schedule = keyword; 437 if (keyword == "static" || keyword == "dynamic" || keyword == "guided") { 438 if (succeeded(parser.parseOptionalEqual())) { 439 chunkSize = OpAsmParser::OperandType{}; 440 if (parser.parseOperand(*chunkSize)) 441 return failure(); 442 } else { 443 chunkSize = llvm::NoneType::None; 444 } 445 } else if (keyword == "auto" || keyword == "runtime") { 446 chunkSize = llvm::NoneType::None; 447 } else { 448 return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; 449 } 450 451 if (parser.parseRParen()) 452 return failure(); 453 454 return success(); 455 } 456 457 /// reduction-init ::= `reduction` `(` reduction-entry-list `)` 458 /// reduction-entry-list ::= reduction-entry 459 /// | reduction-entry-list `,` reduction-entry 460 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type 461 static ParseResult 462 parseReductionVarList(OpAsmParser &parser, 463 SmallVectorImpl<SymbolRefAttr> &symbols, 464 SmallVectorImpl<OpAsmParser::OperandType> &operands, 465 SmallVectorImpl<Type> &types) { 466 if (failed(parser.parseLParen())) 467 return failure(); 468 469 do { 470 if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() || 471 parser.parseOperand(operands.emplace_back()) || 472 parser.parseColonType(types.emplace_back())) 473 return failure(); 474 } while (succeeded(parser.parseOptionalComma())); 475 return parser.parseRParen(); 476 } 477 478 /// Parses an OpenMP Workshare Loop operation 479 /// 480 /// operation ::= `omp.wsloop` loop-control clause-list 481 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds 482 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps 483 /// steps := `step` `(`ssa-id-list`)` 484 /// clause-list ::= clause | empty | clause-list 485 /// clause ::= private | firstprivate | lastprivate | linear | schedule | 486 // collapse | nowait | ordered | order | inclusive 487 /// private ::= `private` `(` ssa-id-and-type-list `)` 488 /// firstprivate ::= `firstprivate` `(` ssa-id-and-type-list `)` 489 /// lastprivate ::= `lastprivate` `(` ssa-id-and-type-list `)` 490 /// linear ::= `linear` `(` linear-list `)` 491 /// schedule ::= `schedule` `(` sched-list `)` 492 /// collapse ::= `collapse` `(` ssa-id-and-type `)` 493 /// nowait ::= `nowait` 494 /// ordered ::= `ordered` `(` ssa-id-and-type `)` 495 /// order ::= `order` `(` `concurrent` `)` 496 /// inclusive ::= `inclusive` 497 /// 498 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) { 499 Type loopVarType; 500 int numIVs; 501 502 // Parse an opening `(` followed by induction variables followed by `)` 503 SmallVector<OpAsmParser::OperandType> ivs; 504 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, 505 OpAsmParser::Delimiter::Paren)) 506 return failure(); 507 508 numIVs = static_cast<int>(ivs.size()); 509 510 if (parser.parseColonType(loopVarType)) 511 return failure(); 512 513 // Parse loop bounds. 514 SmallVector<OpAsmParser::OperandType> lower; 515 if (parser.parseEqual() || 516 parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) || 517 parser.resolveOperands(lower, loopVarType, result.operands)) 518 return failure(); 519 520 SmallVector<OpAsmParser::OperandType> upper; 521 if (parser.parseKeyword("to") || 522 parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) || 523 parser.resolveOperands(upper, loopVarType, result.operands)) 524 return failure(); 525 526 // Parse step values. 527 SmallVector<OpAsmParser::OperandType> steps; 528 if (parser.parseKeyword("step") || 529 parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) || 530 parser.resolveOperands(steps, loopVarType, result.operands)) 531 return failure(); 532 533 SmallVector<OpAsmParser::OperandType> privates; 534 SmallVector<Type> privateTypes; 535 SmallVector<OpAsmParser::OperandType> firstprivates; 536 SmallVector<Type> firstprivateTypes; 537 SmallVector<OpAsmParser::OperandType> lastprivates; 538 SmallVector<Type> lastprivateTypes; 539 SmallVector<OpAsmParser::OperandType> linears; 540 SmallVector<Type> linearTypes; 541 SmallVector<OpAsmParser::OperandType> linearSteps; 542 SmallVector<SymbolRefAttr> reductionSymbols; 543 SmallVector<OpAsmParser::OperandType> reductionVars; 544 SmallVector<Type> reductionVarTypes; 545 SmallString<8> schedule; 546 Optional<OpAsmParser::OperandType> scheduleChunkSize; 547 548 const StringRef opName = result.name.getStringRef(); 549 StringRef keyword; 550 551 enum SegmentPos { 552 lbPos = 0, 553 ubPos, 554 stepPos, 555 privateClausePos, 556 firstprivateClausePos, 557 lastprivateClausePos, 558 linearClausePos, 559 linearStepPos, 560 reductionVarPos, 561 scheduleClausePos, 562 }; 563 std::array<int, 10> segments{numIVs, numIVs, numIVs, 0, 0, 0, 0, 0, 0, 0}; 564 565 while (succeeded(parser.parseOptionalKeyword(&keyword))) { 566 if (keyword == "private") { 567 if (segments[privateClausePos]) 568 return allowedOnce(parser, "private", opName); 569 if (parseOperandAndTypeList(parser, privates, privateTypes)) 570 return failure(); 571 segments[privateClausePos] = privates.size(); 572 } else if (keyword == "firstprivate") { 573 // fail if there was already another firstprivate clause 574 if (segments[firstprivateClausePos]) 575 return allowedOnce(parser, "firstprivate", opName); 576 if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) 577 return failure(); 578 segments[firstprivateClausePos] = firstprivates.size(); 579 } else if (keyword == "lastprivate") { 580 // fail if there was already another shared clause 581 if (segments[lastprivateClausePos]) 582 return allowedOnce(parser, "lastprivate", opName); 583 if (parseOperandAndTypeList(parser, lastprivates, lastprivateTypes)) 584 return failure(); 585 segments[lastprivateClausePos] = lastprivates.size(); 586 } else if (keyword == "linear") { 587 // fail if there was already another linear clause 588 if (segments[linearClausePos]) 589 return allowedOnce(parser, "linear", opName); 590 if (parseLinearClause(parser, linears, linearTypes, linearSteps)) 591 return failure(); 592 segments[linearClausePos] = linears.size(); 593 segments[linearStepPos] = linearSteps.size(); 594 } else if (keyword == "schedule") { 595 if (!schedule.empty()) 596 return allowedOnce(parser, "schedule", opName); 597 if (parseScheduleClause(parser, schedule, scheduleChunkSize)) 598 return failure(); 599 if (scheduleChunkSize) { 600 segments[scheduleClausePos] = 1; 601 } 602 } else if (keyword == "collapse") { 603 auto type = parser.getBuilder().getI64Type(); 604 mlir::IntegerAttr attr; 605 if (parser.parseLParen() || parser.parseAttribute(attr, type) || 606 parser.parseRParen()) 607 return failure(); 608 result.addAttribute("collapse_val", attr); 609 } else if (keyword == "nowait") { 610 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 611 result.addAttribute("nowait", attr); 612 } else if (keyword == "ordered") { 613 mlir::IntegerAttr attr; 614 if (succeeded(parser.parseOptionalLParen())) { 615 auto type = parser.getBuilder().getI64Type(); 616 if (parser.parseAttribute(attr, type)) 617 return failure(); 618 if (parser.parseRParen()) 619 return failure(); 620 } else { 621 // Use 0 to represent no ordered parameter was specified 622 attr = parser.getBuilder().getI64IntegerAttr(0); 623 } 624 result.addAttribute("ordered_val", attr); 625 } else if (keyword == "order") { 626 StringRef order; 627 if (parser.parseLParen() || parser.parseKeyword(&order) || 628 parser.parseRParen()) 629 return failure(); 630 auto attr = parser.getBuilder().getStringAttr(order); 631 result.addAttribute("order", attr); 632 } else if (keyword == "inclusive") { 633 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 634 result.addAttribute("inclusive", attr); 635 } else if (keyword == "reduction") { 636 if (segments[reductionVarPos]) 637 return allowedOnce(parser, "reduction", opName); 638 if (failed(parseReductionVarList(parser, reductionSymbols, reductionVars, 639 reductionVarTypes))) 640 return failure(); 641 segments[reductionVarPos] = reductionVars.size(); 642 } 643 } 644 645 if (segments[privateClausePos]) { 646 parser.resolveOperands(privates, privateTypes, privates[0].location, 647 result.operands); 648 } 649 650 if (segments[firstprivateClausePos]) { 651 parser.resolveOperands(firstprivates, firstprivateTypes, 652 firstprivates[0].location, result.operands); 653 } 654 655 if (segments[lastprivateClausePos]) { 656 parser.resolveOperands(lastprivates, lastprivateTypes, 657 lastprivates[0].location, result.operands); 658 } 659 660 if (segments[linearClausePos]) { 661 parser.resolveOperands(linears, linearTypes, linears[0].location, 662 result.operands); 663 auto linearStepType = parser.getBuilder().getI32Type(); 664 SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType); 665 parser.resolveOperands(linearSteps, linearStepTypes, 666 linearSteps[0].location, result.operands); 667 } 668 669 if (segments[reductionVarPos]) { 670 if (failed(parser.resolveOperands(reductionVars, reductionVarTypes, 671 parser.getNameLoc(), result.operands))) { 672 return failure(); 673 } 674 SmallVector<Attribute> reductions(reductionSymbols.begin(), 675 reductionSymbols.end()); 676 result.addAttribute("reductions", 677 parser.getBuilder().getArrayAttr(reductions)); 678 } 679 680 if (!schedule.empty()) { 681 schedule[0] = llvm::toUpper(schedule[0]); 682 auto attr = parser.getBuilder().getStringAttr(schedule); 683 result.addAttribute("schedule_val", attr); 684 if (scheduleChunkSize) { 685 auto chunkSizeType = parser.getBuilder().getI32Type(); 686 parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands); 687 } 688 } 689 690 result.addAttribute("operand_segment_sizes", 691 parser.getBuilder().getI32VectorAttr(segments)); 692 693 // Now parse the body. 694 Region *body = result.addRegion(); 695 SmallVector<Type> ivTypes(numIVs, loopVarType); 696 SmallVector<OpAsmParser::OperandType> blockArgs(ivs); 697 if (parser.parseRegion(*body, blockArgs, ivTypes)) 698 return failure(); 699 return success(); 700 } 701 702 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) { 703 auto args = op.getRegion().front().getArguments(); 704 p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound() 705 << ") to (" << op.upperBound() << ") step (" << op.step() << ")"; 706 707 // Print private, firstprivate, shared and copyin parameters 708 auto printDataVars = [&p](StringRef name, OperandRange vars) { 709 if (vars.empty()) 710 return; 711 712 p << " " << name << "("; 713 llvm::interleaveComma( 714 vars, p, [&](const Value &v) { p << v << " : " << v.getType(); }); 715 p << ")"; 716 }; 717 printDataVars("private", op.private_vars()); 718 printDataVars("firstprivate", op.firstprivate_vars()); 719 printDataVars("lastprivate", op.lastprivate_vars()); 720 721 auto linearVars = op.linear_vars(); 722 auto linearVarsSize = linearVars.size(); 723 if (linearVarsSize) { 724 p << " " 725 << "linear" 726 << "("; 727 for (unsigned i = 0; i < linearVarsSize; ++i) { 728 std::string separator = i == linearVarsSize - 1 ? ")" : ", "; 729 p << linearVars[i]; 730 if (op.linear_step_vars().size() > i) 731 p << " = " << op.linear_step_vars()[i]; 732 p << " : " << linearVars[i].getType() << separator; 733 } 734 } 735 736 if (auto sched = op.schedule_val()) { 737 auto schedLower = sched->lower(); 738 p << " schedule(" << schedLower; 739 if (auto chunk = op.schedule_chunk_var()) { 740 p << " = " << chunk; 741 } 742 p << ")"; 743 } 744 745 if (auto collapse = op.collapse_val()) 746 p << " collapse(" << collapse << ")"; 747 748 if (op.nowait()) 749 p << " nowait"; 750 751 if (auto ordered = op.ordered_val()) { 752 p << " ordered(" << ordered << ")"; 753 } 754 755 if (!op.reduction_vars().empty()) { 756 p << " reduction("; 757 for (unsigned i = 0, e = op.getNumReductionVars(); i < e; ++i) { 758 if (i != 0) 759 p << ", "; 760 p << (*op.reductions())[i] << " -> " << op.reduction_vars()[i] << " : " 761 << op.reduction_vars()[i].getType(); 762 } 763 p << ")"; 764 } 765 766 if (op.inclusive()) { 767 p << " inclusive"; 768 } 769 770 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 771 } 772 773 //===----------------------------------------------------------------------===// 774 // ReductionOp 775 //===----------------------------------------------------------------------===// 776 777 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, 778 Region ®ion) { 779 if (parser.parseOptionalKeyword("atomic")) 780 return success(); 781 return parser.parseRegion(region); 782 } 783 784 static void printAtomicReductionRegion(OpAsmPrinter &printer, 785 ReductionDeclareOp op, Region ®ion) { 786 if (region.empty()) 787 return; 788 printer << "atomic "; 789 printer.printRegion(region); 790 } 791 792 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) { 793 if (op.initializerRegion().empty()) 794 return op.emitOpError() << "expects non-empty initializer region"; 795 Block &initializerEntryBlock = op.initializerRegion().front(); 796 if (initializerEntryBlock.getNumArguments() != 1 || 797 initializerEntryBlock.getArgument(0).getType() != op.type()) { 798 return op.emitOpError() << "expects initializer region with one argument " 799 "of the reduction type"; 800 } 801 802 for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) { 803 if (yieldOp.results().size() != 1 || 804 yieldOp.results().getTypes()[0] != op.type()) 805 return op.emitOpError() << "expects initializer region to yield a value " 806 "of the reduction type"; 807 } 808 809 if (op.reductionRegion().empty()) 810 return op.emitOpError() << "expects non-empty reduction region"; 811 Block &reductionEntryBlock = op.reductionRegion().front(); 812 if (reductionEntryBlock.getNumArguments() != 2 || 813 reductionEntryBlock.getArgumentTypes()[0] != 814 reductionEntryBlock.getArgumentTypes()[1] || 815 reductionEntryBlock.getArgumentTypes()[0] != op.type()) 816 return op.emitOpError() << "expects reduction region with two arguments of " 817 "the reduction type"; 818 for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) { 819 if (yieldOp.results().size() != 1 || 820 yieldOp.results().getTypes()[0] != op.type()) 821 return op.emitOpError() << "expects reduction region to yield a value " 822 "of the reduction type"; 823 } 824 825 if (op.atomicReductionRegion().empty()) 826 return success(); 827 828 Block &atomicReductionEntryBlock = op.atomicReductionRegion().front(); 829 if (atomicReductionEntryBlock.getNumArguments() != 2 || 830 atomicReductionEntryBlock.getArgumentTypes()[0] != 831 atomicReductionEntryBlock.getArgumentTypes()[1]) 832 return op.emitOpError() << "expects atomic reduction region with two " 833 "arguments of the same type"; 834 auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0] 835 .dyn_cast<PointerLikeType>(); 836 if (!ptrType || ptrType.getElementType() != op.type()) 837 return op.emitOpError() << "expects atomic reduction region arguments to " 838 "be accumulators containing the reduction type"; 839 return success(); 840 } 841 842 static LogicalResult verifyReductionOp(ReductionOp op) { 843 // TODO: generalize this to an op interface when there is more than one op 844 // that supports reductions. 845 auto container = op->getParentOfType<WsLoopOp>(); 846 for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) 847 if (container.reduction_vars()[i] == op.accumulator()) 848 return success(); 849 850 return op.emitOpError() << "the accumulator is not used by the parent"; 851 } 852 853 //===----------------------------------------------------------------------===// 854 // WsLoopOp 855 //===----------------------------------------------------------------------===// 856 857 void WsLoopOp::build(OpBuilder &builder, OperationState &state, 858 ValueRange lowerBound, ValueRange upperBound, 859 ValueRange step, ArrayRef<NamedAttribute> attributes) { 860 build(builder, state, TypeRange(), lowerBound, upperBound, step, 861 /*private_vars=*/ValueRange(), 862 /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(), 863 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), 864 /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr, 865 /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr, 866 /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr, 867 /*inclusive=*/nullptr, /*buildBody=*/false); 868 state.addAttributes(attributes); 869 } 870 871 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes, 872 ValueRange operands, ArrayRef<NamedAttribute> attributes) { 873 state.addOperands(operands); 874 state.addAttributes(attributes); 875 (void)state.addRegion(); 876 assert(resultTypes.empty() && "mismatched number of return types"); 877 state.addTypes(resultTypes); 878 } 879 880 void WsLoopOp::build(OpBuilder &builder, OperationState &result, 881 TypeRange typeRange, ValueRange lowerBounds, 882 ValueRange upperBounds, ValueRange steps, 883 ValueRange privateVars, ValueRange firstprivateVars, 884 ValueRange lastprivateVars, ValueRange linearVars, 885 ValueRange linearStepVars, ValueRange reductionVars, 886 StringAttr scheduleVal, Value scheduleChunkVar, 887 IntegerAttr collapseVal, UnitAttr nowait, 888 IntegerAttr orderedVal, StringAttr orderVal, 889 UnitAttr inclusive, bool buildBody) { 890 result.addOperands(lowerBounds); 891 result.addOperands(upperBounds); 892 result.addOperands(steps); 893 result.addOperands(privateVars); 894 result.addOperands(firstprivateVars); 895 result.addOperands(linearVars); 896 result.addOperands(linearStepVars); 897 if (scheduleChunkVar) 898 result.addOperands(scheduleChunkVar); 899 900 if (scheduleVal) 901 result.addAttribute("schedule_val", scheduleVal); 902 if (collapseVal) 903 result.addAttribute("collapse_val", collapseVal); 904 if (nowait) 905 result.addAttribute("nowait", nowait); 906 if (orderedVal) 907 result.addAttribute("ordered_val", orderedVal); 908 if (orderVal) 909 result.addAttribute("order", orderVal); 910 if (inclusive) 911 result.addAttribute("inclusive", inclusive); 912 result.addAttribute( 913 WsLoopOp::getOperandSegmentSizeAttr(), 914 builder.getI32VectorAttr( 915 {static_cast<int32_t>(lowerBounds.size()), 916 static_cast<int32_t>(upperBounds.size()), 917 static_cast<int32_t>(steps.size()), 918 static_cast<int32_t>(privateVars.size()), 919 static_cast<int32_t>(firstprivateVars.size()), 920 static_cast<int32_t>(lastprivateVars.size()), 921 static_cast<int32_t>(linearVars.size()), 922 static_cast<int32_t>(linearStepVars.size()), 923 static_cast<int32_t>(reductionVars.size()), 924 static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)})); 925 926 Region *bodyRegion = result.addRegion(); 927 if (buildBody) { 928 OpBuilder::InsertionGuard guard(builder); 929 unsigned numIVs = steps.size(); 930 SmallVector<Type, 8> argTypes(numIVs, steps.getType().front()); 931 builder.createBlock(bodyRegion, {}, argTypes); 932 } 933 } 934 935 static LogicalResult verifyWsLoopOp(WsLoopOp op) { 936 if (op.getNumReductionVars() != 0) { 937 if (!op.reductions() || 938 op.reductions()->size() != op.getNumReductionVars()) { 939 return op.emitOpError() << "expected as many reduction symbol references " 940 "as reduction variables"; 941 } 942 } else { 943 if (op.reductions()) 944 return op.emitOpError() << "unexpected reduction symbol references"; 945 return success(); 946 } 947 948 DenseSet<Value> accumulators; 949 for (auto args : llvm::zip(op.reduction_vars(), *op.reductions())) { 950 Value accum = std::get<0>(args); 951 if (!accumulators.insert(accum).second) { 952 return op.emitOpError() << "accumulator variable used more than once"; 953 } 954 Type varType = accum.getType().cast<PointerLikeType>(); 955 auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>(); 956 auto decl = 957 SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef); 958 if (!decl) { 959 return op.emitOpError() << "expected symbol reference " << symbolRef 960 << " to point to a reduction declaration"; 961 } 962 963 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) { 964 return op.emitOpError() 965 << "expected accumulator (" << varType 966 << ") to be the same type as reduction declaration (" 967 << decl.getAccumulatorType() << ")"; 968 } 969 } 970 971 return success(); 972 } 973 974 static LogicalResult verifyCriticalOp(CriticalOp op) { 975 if (!op.name().hasValue() && op.hint().hasValue() && 976 (op.hint().getValue() != SyncHintKind::none)) 977 return op.emitOpError() << "must specify a name unless the effect is as if " 978 "hint(none) is specified"; 979 980 if (op.nameAttr()) { 981 auto symbolRef = op.nameAttr().cast<SymbolRefAttr>(); 982 auto decl = 983 SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef); 984 if (!decl) { 985 return op.emitOpError() << "expected symbol reference " << symbolRef 986 << " to point to a critical declaration"; 987 } 988 } 989 990 return success(); 991 } 992 993 #define GET_OP_CLASSES 994 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 995