1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the OpenMP dialect and its operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 15 #include "mlir/Dialect/StandardOps/IR/Ops.h" 16 #include "mlir/IR/Attributes.h" 17 #include "mlir/IR/OpImplementation.h" 18 #include "mlir/IR/OperationSupport.h" 19 20 #include "llvm/ADT/SmallString.h" 21 #include "llvm/ADT/StringExtras.h" 22 #include "llvm/ADT/StringRef.h" 23 #include "llvm/ADT/StringSwitch.h" 24 #include <cstddef> 25 26 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc" 27 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" 28 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc" 29 30 using namespace mlir; 31 using namespace mlir::omp; 32 33 namespace { 34 /// Model for pointer-like types that already provide a `getElementType` method. 35 template <typename T> 36 struct PointerLikeModel 37 : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> { 38 Type getElementType(Type pointer) const { 39 return pointer.cast<T>().getElementType(); 40 } 41 }; 42 } // end namespace 43 44 void OpenMPDialect::initialize() { 45 addOperations< 46 #define GET_OP_LIST 47 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 48 >(); 49 50 LLVM::LLVMPointerType::attachInterface< 51 PointerLikeModel<LLVM::LLVMPointerType>>(*getContext()); 52 MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext()); 53 } 54 55 //===----------------------------------------------------------------------===// 56 // ParallelOp 57 //===----------------------------------------------------------------------===// 58 59 void ParallelOp::build(OpBuilder &builder, OperationState &state, 60 ArrayRef<NamedAttribute> attributes) { 61 ParallelOp::build( 62 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, 63 /*default_val=*/nullptr, /*private_vars=*/ValueRange(), 64 /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(), 65 /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(), 66 /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr); 67 state.addAttributes(attributes); 68 } 69 70 /// Parse a list of operands with types. 71 /// 72 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)` 73 /// ssa-id-and-type-list ::= ssa-id-and-type | 74 /// ssa-id-and-type `,` ssa-id-and-type-list 75 /// ssa-id-and-type ::= ssa-id `:` type 76 static ParseResult 77 parseOperandAndTypeList(OpAsmParser &parser, 78 SmallVectorImpl<OpAsmParser::OperandType> &operands, 79 SmallVectorImpl<Type> &types) { 80 if (parser.parseLParen()) 81 return failure(); 82 83 do { 84 OpAsmParser::OperandType operand; 85 Type type; 86 if (parser.parseOperand(operand) || parser.parseColonType(type)) 87 return failure(); 88 operands.push_back(operand); 89 types.push_back(type); 90 } while (succeeded(parser.parseOptionalComma())); 91 92 if (parser.parseRParen()) 93 return failure(); 94 95 return success(); 96 } 97 98 /// Parse an allocate clause with allocators and a list of operands with types. 99 /// 100 /// operand-and-type-list ::= `(` allocate-operand-list `)` 101 /// allocate-operand-list :: = allocate-operand | 102 /// allocator-operand `,` allocate-operand-list 103 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 104 /// ssa-id-and-type ::= ssa-id `:` type 105 static ParseResult parseAllocateAndAllocator( 106 OpAsmParser &parser, 107 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate, 108 SmallVectorImpl<Type> &typesAllocate, 109 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator, 110 SmallVectorImpl<Type> &typesAllocator) { 111 if (parser.parseLParen()) 112 return failure(); 113 114 do { 115 OpAsmParser::OperandType operand; 116 Type type; 117 118 if (parser.parseOperand(operand) || parser.parseColonType(type)) 119 return failure(); 120 operandsAllocator.push_back(operand); 121 typesAllocator.push_back(type); 122 if (parser.parseArrow()) 123 return failure(); 124 if (parser.parseOperand(operand) || parser.parseColonType(type)) 125 return failure(); 126 127 operandsAllocate.push_back(operand); 128 typesAllocate.push_back(type); 129 } while (succeeded(parser.parseOptionalComma())); 130 131 if (parser.parseRParen()) 132 return failure(); 133 134 return success(); 135 } 136 137 static LogicalResult verifyParallelOp(ParallelOp op) { 138 if (op.allocate_vars().size() != op.allocators_vars().size()) 139 return op.emitError( 140 "expected equal sizes for allocate and allocator variables"); 141 return success(); 142 } 143 144 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) { 145 p << "omp.parallel"; 146 147 if (auto ifCond = op.if_expr_var()) 148 p << " if(" << ifCond << " : " << ifCond.getType() << ")"; 149 150 if (auto threads = op.num_threads_var()) 151 p << " num_threads(" << threads << " : " << threads.getType() << ")"; 152 153 // Print private, firstprivate, shared and copyin parameters 154 auto printDataVars = [&p](StringRef name, OperandRange vars) { 155 if (vars.size()) { 156 p << " " << name << "("; 157 for (unsigned i = 0; i < vars.size(); ++i) { 158 std::string separator = i == vars.size() - 1 ? ")" : ", "; 159 p << vars[i] << " : " << vars[i].getType() << separator; 160 } 161 } 162 }; 163 164 // Print allocator and allocate parameters 165 auto printAllocateAndAllocator = [&p](OperandRange varsAllocate, 166 OperandRange varsAllocator) { 167 if (varsAllocate.empty()) 168 return; 169 170 p << " allocate("; 171 for (unsigned i = 0; i < varsAllocate.size(); ++i) { 172 std::string separator = i == varsAllocate.size() - 1 ? ")" : ", "; 173 p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> "; 174 p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator; 175 } 176 }; 177 178 printDataVars("private", op.private_vars()); 179 printDataVars("firstprivate", op.firstprivate_vars()); 180 printDataVars("shared", op.shared_vars()); 181 printDataVars("copyin", op.copyin_vars()); 182 printAllocateAndAllocator(op.allocate_vars(), op.allocators_vars()); 183 184 if (auto def = op.default_val()) 185 p << " default(" << def->drop_front(3) << ")"; 186 187 if (auto bind = op.proc_bind_val()) 188 p << " proc_bind(" << bind << ")"; 189 190 p.printRegion(op.getRegion()); 191 } 192 193 /// Emit an error if the same clause is present more than once on an operation. 194 static ParseResult allowedOnce(OpAsmParser &parser, StringRef clause, 195 StringRef operation) { 196 return parser.emitError(parser.getNameLoc()) 197 << " at most one " << clause << " clause can appear on the " 198 << operation << " operation"; 199 } 200 201 /// Parses a parallel operation. 202 /// 203 /// operation ::= `omp.parallel` clause-list 204 /// clause-list ::= clause | clause clause-list 205 /// clause ::= if | numThreads | private | firstprivate | shared | copyin | 206 /// default | procBind 207 /// if ::= `if` `(` ssa-id `)` 208 /// numThreads ::= `num_threads` `(` ssa-id-and-type `)` 209 /// private ::= `private` operand-and-type-list 210 /// firstprivate ::= `firstprivate` operand-and-type-list 211 /// shared ::= `shared` operand-and-type-list 212 /// copyin ::= `copyin` operand-and-type-list 213 /// allocate ::= `allocate` operand-and-type `->` operand-and-type-list 214 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) 215 /// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` 216 /// 217 /// Note that each clause can only appear once in the clase-list. 218 static ParseResult parseParallelOp(OpAsmParser &parser, 219 OperationState &result) { 220 std::pair<OpAsmParser::OperandType, Type> ifCond; 221 std::pair<OpAsmParser::OperandType, Type> numThreads; 222 SmallVector<OpAsmParser::OperandType, 4> privates; 223 SmallVector<Type, 4> privateTypes; 224 SmallVector<OpAsmParser::OperandType, 4> firstprivates; 225 SmallVector<Type, 4> firstprivateTypes; 226 SmallVector<OpAsmParser::OperandType, 4> shareds; 227 SmallVector<Type, 4> sharedTypes; 228 SmallVector<OpAsmParser::OperandType, 4> copyins; 229 SmallVector<Type, 4> copyinTypes; 230 SmallVector<OpAsmParser::OperandType, 4> allocates; 231 SmallVector<Type, 4> allocateTypes; 232 SmallVector<OpAsmParser::OperandType, 4> allocators; 233 SmallVector<Type, 4> allocatorTypes; 234 std::array<int, 8> segments{0, 0, 0, 0, 0, 0, 0, 0}; 235 StringRef keyword; 236 bool defaultVal = false; 237 bool procBind = false; 238 239 const int ifClausePos = 0; 240 const int numThreadsClausePos = 1; 241 const int privateClausePos = 2; 242 const int firstprivateClausePos = 3; 243 const int sharedClausePos = 4; 244 const int copyinClausePos = 5; 245 const int allocateClausePos = 6; 246 const int allocatorPos = 7; 247 const StringRef opName = result.name.getStringRef(); 248 249 while (succeeded(parser.parseOptionalKeyword(&keyword))) { 250 if (keyword == "if") { 251 // Fail if there was already another if condition. 252 if (segments[ifClausePos]) 253 return allowedOnce(parser, "if", opName); 254 if (parser.parseLParen() || parser.parseOperand(ifCond.first) || 255 parser.parseColonType(ifCond.second) || parser.parseRParen()) 256 return failure(); 257 segments[ifClausePos] = 1; 258 } else if (keyword == "num_threads") { 259 // Fail if there was already another num_threads clause. 260 if (segments[numThreadsClausePos]) 261 return allowedOnce(parser, "num_threads", opName); 262 if (parser.parseLParen() || parser.parseOperand(numThreads.first) || 263 parser.parseColonType(numThreads.second) || parser.parseRParen()) 264 return failure(); 265 segments[numThreadsClausePos] = 1; 266 } else if (keyword == "private") { 267 // Fail if there was already another private clause. 268 if (segments[privateClausePos]) 269 return allowedOnce(parser, "private", opName); 270 if (parseOperandAndTypeList(parser, privates, privateTypes)) 271 return failure(); 272 segments[privateClausePos] = privates.size(); 273 } else if (keyword == "firstprivate") { 274 // Fail if there was already another firstprivate clause. 275 if (segments[firstprivateClausePos]) 276 return allowedOnce(parser, "firstprivate", opName); 277 if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) 278 return failure(); 279 segments[firstprivateClausePos] = firstprivates.size(); 280 } else if (keyword == "shared") { 281 // Fail if there was already another shared clause. 282 if (segments[sharedClausePos]) 283 return allowedOnce(parser, "shared", opName); 284 if (parseOperandAndTypeList(parser, shareds, sharedTypes)) 285 return failure(); 286 segments[sharedClausePos] = shareds.size(); 287 } else if (keyword == "copyin") { 288 // Fail if there was already another copyin clause. 289 if (segments[copyinClausePos]) 290 return allowedOnce(parser, "copyin", opName); 291 if (parseOperandAndTypeList(parser, copyins, copyinTypes)) 292 return failure(); 293 segments[copyinClausePos] = copyins.size(); 294 } else if (keyword == "allocate") { 295 // Fail if there was already another allocate clause. 296 if (segments[allocateClausePos]) 297 return allowedOnce(parser, "allocate", opName); 298 if (parseAllocateAndAllocator(parser, allocates, allocateTypes, 299 allocators, allocatorTypes)) 300 return failure(); 301 segments[allocateClausePos] = allocates.size(); 302 segments[allocatorPos] = allocators.size(); 303 } else if (keyword == "default") { 304 // Fail if there was already another default clause. 305 if (defaultVal) 306 return allowedOnce(parser, "default", opName); 307 defaultVal = true; 308 StringRef defval; 309 if (parser.parseLParen() || parser.parseKeyword(&defval) || 310 parser.parseRParen()) 311 return failure(); 312 // The def prefix is required for the attribute as "private" is a keyword 313 // in C++. 314 auto attr = parser.getBuilder().getStringAttr("def" + defval); 315 result.addAttribute("default_val", attr); 316 } else if (keyword == "proc_bind") { 317 // Fail if there was already another proc_bind clause. 318 if (procBind) 319 return allowedOnce(parser, "proc_bind", opName); 320 procBind = true; 321 StringRef bind; 322 if (parser.parseLParen() || parser.parseKeyword(&bind) || 323 parser.parseRParen()) 324 return failure(); 325 auto attr = parser.getBuilder().getStringAttr(bind); 326 result.addAttribute("proc_bind_val", attr); 327 } else { 328 return parser.emitError(parser.getNameLoc()) 329 << keyword << " is not a valid clause for the " << opName 330 << " operation"; 331 } 332 } 333 334 // Add if parameter. 335 if (segments[ifClausePos] && 336 parser.resolveOperand(ifCond.first, ifCond.second, result.operands)) 337 return failure(); 338 339 // Add num_threads parameter. 340 if (segments[numThreadsClausePos] && 341 parser.resolveOperand(numThreads.first, numThreads.second, 342 result.operands)) 343 return failure(); 344 345 // Add private parameters. 346 if (segments[privateClausePos] && 347 parser.resolveOperands(privates, privateTypes, privates[0].location, 348 result.operands)) 349 return failure(); 350 351 // Add firstprivate parameters. 352 if (segments[firstprivateClausePos] && 353 parser.resolveOperands(firstprivates, firstprivateTypes, 354 firstprivates[0].location, result.operands)) 355 return failure(); 356 357 // Add shared parameters. 358 if (segments[sharedClausePos] && 359 parser.resolveOperands(shareds, sharedTypes, shareds[0].location, 360 result.operands)) 361 return failure(); 362 363 // Add copyin parameters. 364 if (segments[copyinClausePos] && 365 parser.resolveOperands(copyins, copyinTypes, copyins[0].location, 366 result.operands)) 367 return failure(); 368 369 // Add allocate parameters. 370 if (segments[allocateClausePos] && 371 parser.resolveOperands(allocates, allocateTypes, allocates[0].location, 372 result.operands)) 373 return failure(); 374 375 // Add allocator parameters. 376 if (segments[allocatorPos] && 377 parser.resolveOperands(allocators, allocatorTypes, allocators[0].location, 378 result.operands)) 379 return failure(); 380 381 result.addAttribute("operand_segment_sizes", 382 parser.getBuilder().getI32VectorAttr(segments)); 383 384 Region *body = result.addRegion(); 385 SmallVector<OpAsmParser::OperandType, 4> regionArgs; 386 SmallVector<Type, 4> regionArgTypes; 387 if (parser.parseRegion(*body, regionArgs, regionArgTypes)) 388 return failure(); 389 return success(); 390 } 391 392 /// linear ::= `linear` `(` linear-list `)` 393 /// linear-list := linear-val | linear-val linear-list 394 /// linear-val := ssa-id-and-type `=` ssa-id-and-type 395 static ParseResult 396 parseLinearClause(OpAsmParser &parser, 397 SmallVectorImpl<OpAsmParser::OperandType> &vars, 398 SmallVectorImpl<Type> &types, 399 SmallVectorImpl<OpAsmParser::OperandType> &stepVars) { 400 if (parser.parseLParen()) 401 return failure(); 402 403 do { 404 OpAsmParser::OperandType var; 405 Type type; 406 OpAsmParser::OperandType stepVar; 407 if (parser.parseOperand(var) || parser.parseEqual() || 408 parser.parseOperand(stepVar) || parser.parseColonType(type)) 409 return failure(); 410 411 vars.push_back(var); 412 types.push_back(type); 413 stepVars.push_back(stepVar); 414 } while (succeeded(parser.parseOptionalComma())); 415 416 if (parser.parseRParen()) 417 return failure(); 418 419 return success(); 420 } 421 422 /// schedule ::= `schedule` `(` sched-list `)` 423 /// sched-list ::= sched-val | sched-val sched-list 424 /// sched-val ::= sched-with-chunk | sched-wo-chunk 425 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? 426 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided` 427 /// sched-wo-chunk ::= `auto` | `runtime` 428 static ParseResult 429 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule, 430 Optional<OpAsmParser::OperandType> &chunkSize) { 431 if (parser.parseLParen()) 432 return failure(); 433 434 StringRef keyword; 435 if (parser.parseKeyword(&keyword)) 436 return failure(); 437 438 schedule = keyword; 439 if (keyword == "static" || keyword == "dynamic" || keyword == "guided") { 440 if (succeeded(parser.parseOptionalEqual())) { 441 chunkSize = OpAsmParser::OperandType{}; 442 if (parser.parseOperand(*chunkSize)) 443 return failure(); 444 } else { 445 chunkSize = llvm::NoneType::None; 446 } 447 } else if (keyword == "auto" || keyword == "runtime") { 448 chunkSize = llvm::NoneType::None; 449 } else { 450 return parser.emitError(parser.getNameLoc()) << " expected schedule kind"; 451 } 452 453 if (parser.parseRParen()) 454 return failure(); 455 456 return success(); 457 } 458 459 /// reduction-init ::= `reduction` `(` reduction-entry-list `)` 460 /// reduction-entry-list ::= reduction-entry 461 /// | reduction-entry-list `,` reduction-entry 462 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type 463 static ParseResult 464 parseReductionVarList(OpAsmParser &parser, 465 SmallVectorImpl<SymbolRefAttr> &symbols, 466 SmallVectorImpl<OpAsmParser::OperandType> &operands, 467 SmallVectorImpl<Type> &types) { 468 if (failed(parser.parseLParen())) 469 return failure(); 470 471 do { 472 if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() || 473 parser.parseOperand(operands.emplace_back()) || 474 parser.parseColonType(types.emplace_back())) 475 return failure(); 476 } while (succeeded(parser.parseOptionalComma())); 477 return parser.parseRParen(); 478 } 479 480 /// Parses an OpenMP Workshare Loop operation 481 /// 482 /// operation ::= `omp.wsloop` loop-control clause-list 483 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds 484 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps 485 /// steps := `step` `(`ssa-id-list`)` 486 /// clause-list ::= clause | empty | clause-list 487 /// clause ::= private | firstprivate | lastprivate | linear | schedule | 488 // collapse | nowait | ordered | order | inclusive 489 /// private ::= `private` `(` ssa-id-and-type-list `)` 490 /// firstprivate ::= `firstprivate` `(` ssa-id-and-type-list `)` 491 /// lastprivate ::= `lastprivate` `(` ssa-id-and-type-list `)` 492 /// linear ::= `linear` `(` linear-list `)` 493 /// schedule ::= `schedule` `(` sched-list `)` 494 /// collapse ::= `collapse` `(` ssa-id-and-type `)` 495 /// nowait ::= `nowait` 496 /// ordered ::= `ordered` `(` ssa-id-and-type `)` 497 /// order ::= `order` `(` `concurrent` `)` 498 /// inclusive ::= `inclusive` 499 /// 500 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) { 501 Type loopVarType; 502 int numIVs; 503 504 // Parse an opening `(` followed by induction variables followed by `)` 505 SmallVector<OpAsmParser::OperandType> ivs; 506 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, 507 OpAsmParser::Delimiter::Paren)) 508 return failure(); 509 510 numIVs = static_cast<int>(ivs.size()); 511 512 if (parser.parseColonType(loopVarType)) 513 return failure(); 514 515 // Parse loop bounds. 516 SmallVector<OpAsmParser::OperandType> lower; 517 if (parser.parseEqual() || 518 parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) || 519 parser.resolveOperands(lower, loopVarType, result.operands)) 520 return failure(); 521 522 SmallVector<OpAsmParser::OperandType> upper; 523 if (parser.parseKeyword("to") || 524 parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) || 525 parser.resolveOperands(upper, loopVarType, result.operands)) 526 return failure(); 527 528 // Parse step values. 529 SmallVector<OpAsmParser::OperandType> steps; 530 if (parser.parseKeyword("step") || 531 parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) || 532 parser.resolveOperands(steps, loopVarType, result.operands)) 533 return failure(); 534 535 SmallVector<OpAsmParser::OperandType> privates; 536 SmallVector<Type> privateTypes; 537 SmallVector<OpAsmParser::OperandType> firstprivates; 538 SmallVector<Type> firstprivateTypes; 539 SmallVector<OpAsmParser::OperandType> lastprivates; 540 SmallVector<Type> lastprivateTypes; 541 SmallVector<OpAsmParser::OperandType> linears; 542 SmallVector<Type> linearTypes; 543 SmallVector<OpAsmParser::OperandType> linearSteps; 544 SmallVector<SymbolRefAttr> reductionSymbols; 545 SmallVector<OpAsmParser::OperandType> reductionVars; 546 SmallVector<Type> reductionVarTypes; 547 SmallString<8> schedule; 548 Optional<OpAsmParser::OperandType> scheduleChunkSize; 549 550 const StringRef opName = result.name.getStringRef(); 551 StringRef keyword; 552 553 enum SegmentPos { 554 lbPos = 0, 555 ubPos, 556 stepPos, 557 privateClausePos, 558 firstprivateClausePos, 559 lastprivateClausePos, 560 linearClausePos, 561 linearStepPos, 562 reductionVarPos, 563 scheduleClausePos, 564 }; 565 std::array<int, 10> segments{numIVs, numIVs, numIVs, 0, 0, 0, 0, 0, 0, 0}; 566 567 while (succeeded(parser.parseOptionalKeyword(&keyword))) { 568 if (keyword == "private") { 569 if (segments[privateClausePos]) 570 return allowedOnce(parser, "private", opName); 571 if (parseOperandAndTypeList(parser, privates, privateTypes)) 572 return failure(); 573 segments[privateClausePos] = privates.size(); 574 } else if (keyword == "firstprivate") { 575 // fail if there was already another firstprivate clause 576 if (segments[firstprivateClausePos]) 577 return allowedOnce(parser, "firstprivate", opName); 578 if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) 579 return failure(); 580 segments[firstprivateClausePos] = firstprivates.size(); 581 } else if (keyword == "lastprivate") { 582 // fail if there was already another shared clause 583 if (segments[lastprivateClausePos]) 584 return allowedOnce(parser, "lastprivate", opName); 585 if (parseOperandAndTypeList(parser, lastprivates, lastprivateTypes)) 586 return failure(); 587 segments[lastprivateClausePos] = lastprivates.size(); 588 } else if (keyword == "linear") { 589 // fail if there was already another linear clause 590 if (segments[linearClausePos]) 591 return allowedOnce(parser, "linear", opName); 592 if (parseLinearClause(parser, linears, linearTypes, linearSteps)) 593 return failure(); 594 segments[linearClausePos] = linears.size(); 595 segments[linearStepPos] = linearSteps.size(); 596 } else if (keyword == "schedule") { 597 if (!schedule.empty()) 598 return allowedOnce(parser, "schedule", opName); 599 if (parseScheduleClause(parser, schedule, scheduleChunkSize)) 600 return failure(); 601 if (scheduleChunkSize) { 602 segments[scheduleClausePos] = 1; 603 } 604 } else if (keyword == "collapse") { 605 auto type = parser.getBuilder().getI64Type(); 606 mlir::IntegerAttr attr; 607 if (parser.parseLParen() || parser.parseAttribute(attr, type) || 608 parser.parseRParen()) 609 return failure(); 610 result.addAttribute("collapse_val", attr); 611 } else if (keyword == "nowait") { 612 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 613 result.addAttribute("nowait", attr); 614 } else if (keyword == "ordered") { 615 mlir::IntegerAttr attr; 616 if (succeeded(parser.parseOptionalLParen())) { 617 auto type = parser.getBuilder().getI64Type(); 618 if (parser.parseAttribute(attr, type)) 619 return failure(); 620 if (parser.parseRParen()) 621 return failure(); 622 } else { 623 // Use 0 to represent no ordered parameter was specified 624 attr = parser.getBuilder().getI64IntegerAttr(0); 625 } 626 result.addAttribute("ordered_val", attr); 627 } else if (keyword == "order") { 628 StringRef order; 629 if (parser.parseLParen() || parser.parseKeyword(&order) || 630 parser.parseRParen()) 631 return failure(); 632 auto attr = parser.getBuilder().getStringAttr(order); 633 result.addAttribute("order", attr); 634 } else if (keyword == "inclusive") { 635 auto attr = UnitAttr::get(parser.getBuilder().getContext()); 636 result.addAttribute("inclusive", attr); 637 } else if (keyword == "reduction") { 638 if (segments[reductionVarPos]) 639 return allowedOnce(parser, "reduction", opName); 640 if (failed(parseReductionVarList(parser, reductionSymbols, reductionVars, 641 reductionVarTypes))) 642 return failure(); 643 segments[reductionVarPos] = reductionVars.size(); 644 } 645 } 646 647 if (segments[privateClausePos]) { 648 parser.resolveOperands(privates, privateTypes, privates[0].location, 649 result.operands); 650 } 651 652 if (segments[firstprivateClausePos]) { 653 parser.resolveOperands(firstprivates, firstprivateTypes, 654 firstprivates[0].location, result.operands); 655 } 656 657 if (segments[lastprivateClausePos]) { 658 parser.resolveOperands(lastprivates, lastprivateTypes, 659 lastprivates[0].location, result.operands); 660 } 661 662 if (segments[linearClausePos]) { 663 parser.resolveOperands(linears, linearTypes, linears[0].location, 664 result.operands); 665 auto linearStepType = parser.getBuilder().getI32Type(); 666 SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType); 667 parser.resolveOperands(linearSteps, linearStepTypes, 668 linearSteps[0].location, result.operands); 669 } 670 671 if (segments[reductionVarPos]) { 672 if (failed(parser.resolveOperands(reductionVars, reductionVarTypes, 673 parser.getNameLoc(), result.operands))) { 674 return failure(); 675 } 676 SmallVector<Attribute> reductions(reductionSymbols.begin(), 677 reductionSymbols.end()); 678 result.addAttribute("reductions", 679 parser.getBuilder().getArrayAttr(reductions)); 680 } 681 682 if (!schedule.empty()) { 683 schedule[0] = llvm::toUpper(schedule[0]); 684 auto attr = parser.getBuilder().getStringAttr(schedule); 685 result.addAttribute("schedule_val", attr); 686 if (scheduleChunkSize) { 687 auto chunkSizeType = parser.getBuilder().getI32Type(); 688 parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands); 689 } 690 } 691 692 result.addAttribute("operand_segment_sizes", 693 parser.getBuilder().getI32VectorAttr(segments)); 694 695 // Now parse the body. 696 Region *body = result.addRegion(); 697 SmallVector<Type> ivTypes(numIVs, loopVarType); 698 SmallVector<OpAsmParser::OperandType> blockArgs(ivs); 699 if (parser.parseRegion(*body, blockArgs, ivTypes)) 700 return failure(); 701 return success(); 702 } 703 704 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) { 705 auto args = op.getRegion().front().getArguments(); 706 p << op.getOperationName() << " (" << args << ") : " << args[0].getType() 707 << " = (" << op.lowerBound() << ") to (" << op.upperBound() << ") step (" 708 << op.step() << ")"; 709 710 // Print private, firstprivate, shared and copyin parameters 711 auto printDataVars = [&p](StringRef name, OperandRange vars) { 712 if (vars.empty()) 713 return; 714 715 p << " " << name << "("; 716 llvm::interleaveComma( 717 vars, p, [&](const Value &v) { p << v << " : " << v.getType(); }); 718 p << ")"; 719 }; 720 printDataVars("private", op.private_vars()); 721 printDataVars("firstprivate", op.firstprivate_vars()); 722 printDataVars("lastprivate", op.lastprivate_vars()); 723 724 auto linearVars = op.linear_vars(); 725 auto linearVarsSize = linearVars.size(); 726 if (linearVarsSize) { 727 p << " " 728 << "linear" 729 << "("; 730 for (unsigned i = 0; i < linearVarsSize; ++i) { 731 std::string separator = i == linearVarsSize - 1 ? ")" : ", "; 732 p << linearVars[i]; 733 if (op.linear_step_vars().size() > i) 734 p << " = " << op.linear_step_vars()[i]; 735 p << " : " << linearVars[i].getType() << separator; 736 } 737 } 738 739 if (auto sched = op.schedule_val()) { 740 auto schedLower = sched->lower(); 741 p << " schedule(" << schedLower; 742 if (auto chunk = op.schedule_chunk_var()) { 743 p << " = " << chunk; 744 } 745 p << ")"; 746 } 747 748 if (auto collapse = op.collapse_val()) 749 p << " collapse(" << collapse << ")"; 750 751 if (op.nowait()) 752 p << " nowait"; 753 754 if (auto ordered = op.ordered_val()) { 755 p << " ordered(" << ordered << ")"; 756 } 757 758 if (!op.reduction_vars().empty()) { 759 p << " reduction("; 760 for (unsigned i = 0, e = op.getNumReductionVars(); i < e; ++i) { 761 if (i != 0) 762 p << ", "; 763 p << (*op.reductions())[i] << " -> " << op.reduction_vars()[i] << " : " 764 << op.reduction_vars()[i].getType(); 765 } 766 p << ")"; 767 } 768 769 if (op.inclusive()) { 770 p << " inclusive"; 771 } 772 773 p.printRegion(op.region(), /*printEntryBlockArgs=*/false); 774 } 775 776 //===----------------------------------------------------------------------===// 777 // ReductionOp 778 //===----------------------------------------------------------------------===// 779 780 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser, 781 Region ®ion) { 782 if (parser.parseOptionalKeyword("atomic")) 783 return success(); 784 return parser.parseRegion(region); 785 } 786 787 static void printAtomicReductionRegion(OpAsmPrinter &printer, 788 ReductionDeclareOp op, Region ®ion) { 789 if (region.empty()) 790 return; 791 printer << "atomic "; 792 printer.printRegion(region); 793 } 794 795 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) { 796 if (op.initializerRegion().empty()) 797 return op.emitOpError() << "expects non-empty initializer region"; 798 Block &initializerEntryBlock = op.initializerRegion().front(); 799 if (initializerEntryBlock.getNumArguments() != 1 || 800 initializerEntryBlock.getArgument(0).getType() != op.type()) { 801 return op.emitOpError() << "expects initializer region with one argument " 802 "of the reduction type"; 803 } 804 805 for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) { 806 if (yieldOp.results().size() != 1 || 807 yieldOp.results().getTypes()[0] != op.type()) 808 return op.emitOpError() << "expects initializer region to yield a value " 809 "of the reduction type"; 810 } 811 812 if (op.reductionRegion().empty()) 813 return op.emitOpError() << "expects non-empty reduction region"; 814 Block &reductionEntryBlock = op.reductionRegion().front(); 815 if (reductionEntryBlock.getNumArguments() != 2 || 816 reductionEntryBlock.getArgumentTypes()[0] != 817 reductionEntryBlock.getArgumentTypes()[1] || 818 reductionEntryBlock.getArgumentTypes()[0] != op.type()) 819 return op.emitOpError() << "expects reduction region with two arguments of " 820 "the reduction type"; 821 for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) { 822 if (yieldOp.results().size() != 1 || 823 yieldOp.results().getTypes()[0] != op.type()) 824 return op.emitOpError() << "expects reduction region to yield a value " 825 "of the reduction type"; 826 } 827 828 if (op.atomicReductionRegion().empty()) 829 return success(); 830 831 Block &atomicReductionEntryBlock = op.atomicReductionRegion().front(); 832 if (atomicReductionEntryBlock.getNumArguments() != 2 || 833 atomicReductionEntryBlock.getArgumentTypes()[0] != 834 atomicReductionEntryBlock.getArgumentTypes()[1]) 835 return op.emitOpError() << "expects atomic reduction region with two " 836 "arguments of the same type"; 837 auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0] 838 .dyn_cast<PointerLikeType>(); 839 if (!ptrType || ptrType.getElementType() != op.type()) 840 return op.emitOpError() << "expects atomic reduction region arguments to " 841 "be accumulators containing the reduction type"; 842 return success(); 843 } 844 845 static LogicalResult verifyReductionOp(ReductionOp op) { 846 // TODO: generalize this to an op interface when there is more than one op 847 // that supports reductions. 848 auto container = op->getParentOfType<WsLoopOp>(); 849 for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) 850 if (container.reduction_vars()[i] == op.accumulator()) 851 return success(); 852 853 return op.emitOpError() << "the accumulator is not used by the parent"; 854 } 855 856 //===----------------------------------------------------------------------===// 857 // WsLoopOp 858 //===----------------------------------------------------------------------===// 859 860 void WsLoopOp::build(OpBuilder &builder, OperationState &state, 861 ValueRange lowerBound, ValueRange upperBound, 862 ValueRange step, ArrayRef<NamedAttribute> attributes) { 863 build(builder, state, TypeRange(), lowerBound, upperBound, step, 864 /*private_vars=*/ValueRange(), 865 /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(), 866 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), 867 /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr, 868 /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr, 869 /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr, 870 /*inclusive=*/nullptr, /*buildBody=*/false); 871 state.addAttributes(attributes); 872 } 873 874 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes, 875 ValueRange operands, ArrayRef<NamedAttribute> attributes) { 876 state.addOperands(operands); 877 state.addAttributes(attributes); 878 (void)state.addRegion(); 879 assert(resultTypes.empty() && "mismatched number of return types"); 880 state.addTypes(resultTypes); 881 } 882 883 void WsLoopOp::build(OpBuilder &builder, OperationState &result, 884 TypeRange typeRange, ValueRange lowerBounds, 885 ValueRange upperBounds, ValueRange steps, 886 ValueRange privateVars, ValueRange firstprivateVars, 887 ValueRange lastprivateVars, ValueRange linearVars, 888 ValueRange linearStepVars, ValueRange reductionVars, 889 StringAttr scheduleVal, Value scheduleChunkVar, 890 IntegerAttr collapseVal, UnitAttr nowait, 891 IntegerAttr orderedVal, StringAttr orderVal, 892 UnitAttr inclusive, bool buildBody) { 893 result.addOperands(lowerBounds); 894 result.addOperands(upperBounds); 895 result.addOperands(steps); 896 result.addOperands(privateVars); 897 result.addOperands(firstprivateVars); 898 result.addOperands(linearVars); 899 result.addOperands(linearStepVars); 900 if (scheduleChunkVar) 901 result.addOperands(scheduleChunkVar); 902 903 if (scheduleVal) 904 result.addAttribute("schedule_val", scheduleVal); 905 if (collapseVal) 906 result.addAttribute("collapse_val", collapseVal); 907 if (nowait) 908 result.addAttribute("nowait", nowait); 909 if (orderedVal) 910 result.addAttribute("ordered_val", orderedVal); 911 if (orderVal) 912 result.addAttribute("order", orderVal); 913 if (inclusive) 914 result.addAttribute("inclusive", inclusive); 915 result.addAttribute( 916 WsLoopOp::getOperandSegmentSizeAttr(), 917 builder.getI32VectorAttr( 918 {static_cast<int32_t>(lowerBounds.size()), 919 static_cast<int32_t>(upperBounds.size()), 920 static_cast<int32_t>(steps.size()), 921 static_cast<int32_t>(privateVars.size()), 922 static_cast<int32_t>(firstprivateVars.size()), 923 static_cast<int32_t>(lastprivateVars.size()), 924 static_cast<int32_t>(linearVars.size()), 925 static_cast<int32_t>(linearStepVars.size()), 926 static_cast<int32_t>(reductionVars.size()), 927 static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)})); 928 929 Region *bodyRegion = result.addRegion(); 930 if (buildBody) { 931 OpBuilder::InsertionGuard guard(builder); 932 unsigned numIVs = steps.size(); 933 SmallVector<Type, 8> argTypes(numIVs, steps.getType().front()); 934 builder.createBlock(bodyRegion, {}, argTypes); 935 } 936 } 937 938 static LogicalResult verifyWsLoopOp(WsLoopOp op) { 939 if (op.getNumReductionVars() != 0) { 940 if (!op.reductions() || 941 op.reductions()->size() != op.getNumReductionVars()) { 942 return op.emitOpError() << "expected as many reduction symbol references " 943 "as reduction variables"; 944 } 945 } else { 946 if (op.reductions()) 947 return op.emitOpError() << "unexpected reduction symbol references"; 948 return success(); 949 } 950 951 DenseSet<Value> accumulators; 952 for (auto args : llvm::zip(op.reduction_vars(), *op.reductions())) { 953 Value accum = std::get<0>(args); 954 if (!accumulators.insert(accum).second) { 955 return op.emitOpError() << "accumulator variable used more than once"; 956 } 957 Type varType = accum.getType().cast<PointerLikeType>(); 958 auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>(); 959 auto decl = 960 SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef); 961 if (!decl) { 962 return op.emitOpError() << "expected symbol reference " << symbolRef 963 << " to point to a reduction declaration"; 964 } 965 966 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) { 967 return op.emitOpError() 968 << "expected accumulator (" << varType 969 << ") to be the same type as reduction declaration (" 970 << decl.getAccumulatorType() << ")"; 971 } 972 } 973 974 return success(); 975 } 976 977 #define GET_OP_CLASSES 978 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 979