//===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the OpenMP dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"

#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include <cstddef>

#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"

using namespace mlir;
using namespace mlir::omp;

/// Registers the operations of the OpenMP dialect. The list of operation
/// classes is expanded from OpenMPOps.cpp.inc with GET_OP_LIST defined, so the
/// #define/#include pair must stay inside the template argument list.
void OpenMPDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// ParallelOp
//===----------------------------------------------------------------------===//

/// Builds an omp.parallel operation that carries no clauses: the optional
/// operands and attributes are null and every variadic operand group is
/// empty. The caller-provided `attributes` are then added to the state.
void ParallelOp::build(OpBuilder &builder, OperationState &state,
                       ArrayRef<NamedAttribute> attributes) {
  ParallelOp::build(
      builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
      /*default_val=*/nullptr, /*private_vars=*/ValueRange(),
      /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(),
      /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(),
      /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr);
  state.addAttributes(attributes);
}

/// Parse a list of operands with types.
52 /// 53 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)` 54 /// ssa-id-and-type-list ::= ssa-id-and-type | 55 /// ssa-id-and-type `,` ssa-id-and-type-list 56 /// ssa-id-and-type ::= ssa-id `:` type 57 static ParseResult 58 parseOperandAndTypeList(OpAsmParser &parser, 59 SmallVectorImpl<OpAsmParser::OperandType> &operands, 60 SmallVectorImpl<Type> &types) { 61 if (parser.parseLParen()) 62 return failure(); 63 64 do { 65 OpAsmParser::OperandType operand; 66 Type type; 67 if (parser.parseOperand(operand) || parser.parseColonType(type)) 68 return failure(); 69 operands.push_back(operand); 70 types.push_back(type); 71 } while (succeeded(parser.parseOptionalComma())); 72 73 if (parser.parseRParen()) 74 return failure(); 75 76 return success(); 77 } 78 79 /// Parse an allocate clause with allocators and a list of operands with types. 80 /// 81 /// operand-and-type-list ::= `(` allocate-operand-list `)` 82 /// allocate-operand-list :: = allocate-operand | 83 /// allocator-operand `,` allocate-operand-list 84 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 85 /// ssa-id-and-type ::= ssa-id `:` type 86 static ParseResult parseAllocateAndAllocator( 87 OpAsmParser &parser, 88 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate, 89 SmallVectorImpl<Type> &typesAllocate, 90 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator, 91 SmallVectorImpl<Type> &typesAllocator) { 92 if (parser.parseLParen()) 93 return failure(); 94 95 do { 96 OpAsmParser::OperandType operand; 97 Type type; 98 99 if (parser.parseOperand(operand) || parser.parseColonType(type)) 100 return failure(); 101 operandsAllocator.push_back(operand); 102 typesAllocator.push_back(type); 103 if (parser.parseArrow()) 104 return failure(); 105 if (parser.parseOperand(operand) || parser.parseColonType(type)) 106 return failure(); 107 108 operandsAllocate.push_back(operand); 109 typesAllocate.push_back(type); 110 } while (succeeded(parser.parseOptionalComma())); 111 112 if 
(parser.parseRParen()) 113 return failure(); 114 115 return success(); 116 } 117 118 static LogicalResult verifyParallelOp(ParallelOp op) { 119 if (op.allocate_vars().size() != op.allocators_vars().size()) 120 return op.emitError( 121 "expected equal sizes for allocate and allocator variables"); 122 return success(); 123 } 124 125 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) { 126 p << "omp.parallel"; 127 128 if (auto ifCond = op.if_expr_var()) 129 p << " if(" << ifCond << " : " << ifCond.getType() << ")"; 130 131 if (auto threads = op.num_threads_var()) 132 p << " num_threads(" << threads << " : " << threads.getType() << ")"; 133 134 // Print private, firstprivate, shared and copyin parameters 135 auto printDataVars = [&p](StringRef name, OperandRange vars) { 136 if (vars.size()) { 137 p << " " << name << "("; 138 for (unsigned i = 0; i < vars.size(); ++i) { 139 std::string separator = i == vars.size() - 1 ? ")" : ", "; 140 p << vars[i] << " : " << vars[i].getType() << separator; 141 } 142 } 143 }; 144 145 // Print allocator and allocate parameters 146 auto printAllocateAndAllocator = [&p](OperandRange varsAllocate, 147 OperandRange varsAllocator) { 148 if (varsAllocate.empty()) 149 return; 150 151 p << " allocate("; 152 for (unsigned i = 0; i < varsAllocate.size(); ++i) { 153 std::string separator = i == varsAllocate.size() - 1 ? 
")" : ", "; 154 p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> "; 155 p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator; 156 } 157 }; 158 159 printDataVars("private", op.private_vars()); 160 printDataVars("firstprivate", op.firstprivate_vars()); 161 printDataVars("shared", op.shared_vars()); 162 printDataVars("copyin", op.copyin_vars()); 163 printAllocateAndAllocator(op.allocate_vars(), op.allocators_vars()); 164 165 if (auto def = op.default_val()) 166 p << " default(" << def->drop_front(3) << ")"; 167 168 if (auto bind = op.proc_bind_val()) 169 p << " proc_bind(" << bind << ")"; 170 171 p.printRegion(op.getRegion()); 172 } 173 174 /// Emit an error if the same clause is present more than once on an operation. 175 static ParseResult allowedOnce(OpAsmParser &parser, llvm::StringRef clause, 176 llvm::StringRef operation) { 177 return parser.emitError(parser.getNameLoc()) 178 << " at most one " << clause << " clause can appear on the " 179 << operation << " operation"; 180 } 181 182 /// Parses a parallel operation. 183 /// 184 /// operation ::= `omp.parallel` clause-list 185 /// clause-list ::= clause | clause clause-list 186 /// clause ::= if | numThreads | private | firstprivate | shared | copyin | 187 /// default | procBind 188 /// if ::= `if` `(` ssa-id `)` 189 /// numThreads ::= `num_threads` `(` ssa-id-and-type `)` 190 /// private ::= `private` operand-and-type-list 191 /// firstprivate ::= `firstprivate` operand-and-type-list 192 /// shared ::= `shared` operand-and-type-list 193 /// copyin ::= `copyin` operand-and-type-list 194 /// allocate ::= `allocate` operand-and-type `->` operand-and-type-list 195 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) 196 /// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` 197 /// 198 /// Note that each clause can only appear once in the clase-list. 
199 static ParseResult parseParallelOp(OpAsmParser &parser, 200 OperationState &result) { 201 std::pair<OpAsmParser::OperandType, Type> ifCond; 202 std::pair<OpAsmParser::OperandType, Type> numThreads; 203 SmallVector<OpAsmParser::OperandType, 4> privates; 204 SmallVector<Type, 4> privateTypes; 205 SmallVector<OpAsmParser::OperandType, 4> firstprivates; 206 SmallVector<Type, 4> firstprivateTypes; 207 SmallVector<OpAsmParser::OperandType, 4> shareds; 208 SmallVector<Type, 4> sharedTypes; 209 SmallVector<OpAsmParser::OperandType, 4> copyins; 210 SmallVector<Type, 4> copyinTypes; 211 SmallVector<OpAsmParser::OperandType, 4> allocates; 212 SmallVector<Type, 4> allocateTypes; 213 SmallVector<OpAsmParser::OperandType, 4> allocators; 214 SmallVector<Type, 4> allocatorTypes; 215 std::array<int, 8> segments{0, 0, 0, 0, 0, 0, 0, 0}; 216 llvm::StringRef keyword; 217 bool defaultVal = false; 218 bool procBind = false; 219 220 const int ifClausePos = 0; 221 const int numThreadsClausePos = 1; 222 const int privateClausePos = 2; 223 const int firstprivateClausePos = 3; 224 const int sharedClausePos = 4; 225 const int copyinClausePos = 5; 226 const int allocateClausePos = 6; 227 const int allocatorPos = 7; 228 const llvm::StringRef opName = result.name.getStringRef(); 229 230 while (succeeded(parser.parseOptionalKeyword(&keyword))) { 231 if (keyword == "if") { 232 // Fail if there was already another if condition 233 if (segments[ifClausePos]) 234 return allowedOnce(parser, "if", opName); 235 if (parser.parseLParen() || parser.parseOperand(ifCond.first) || 236 parser.parseColonType(ifCond.second) || parser.parseRParen()) 237 return failure(); 238 segments[ifClausePos] = 1; 239 } else if (keyword == "num_threads") { 240 // fail if there was already another num_threads clause 241 if (segments[numThreadsClausePos]) 242 return allowedOnce(parser, "num_threads", opName); 243 if (parser.parseLParen() || parser.parseOperand(numThreads.first) || 244 
parser.parseColonType(numThreads.second) || parser.parseRParen()) 245 return failure(); 246 segments[numThreadsClausePos] = 1; 247 } else if (keyword == "private") { 248 // fail if there was already another private clause 249 if (segments[privateClausePos]) 250 return allowedOnce(parser, "private", opName); 251 if (parseOperandAndTypeList(parser, privates, privateTypes)) 252 return failure(); 253 segments[privateClausePos] = privates.size(); 254 } else if (keyword == "firstprivate") { 255 // fail if there was already another firstprivate clause 256 if (segments[firstprivateClausePos]) 257 return allowedOnce(parser, "firstprivate", opName); 258 if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) 259 return failure(); 260 segments[firstprivateClausePos] = firstprivates.size(); 261 } else if (keyword == "shared") { 262 // fail if there was already another shared clause 263 if (segments[sharedClausePos]) 264 return allowedOnce(parser, "shared", opName); 265 if (parseOperandAndTypeList(parser, shareds, sharedTypes)) 266 return failure(); 267 segments[sharedClausePos] = shareds.size(); 268 } else if (keyword == "copyin") { 269 // fail if there was already another copyin clause 270 if (segments[copyinClausePos]) 271 return allowedOnce(parser, "copyin", opName); 272 if (parseOperandAndTypeList(parser, copyins, copyinTypes)) 273 return failure(); 274 segments[copyinClausePos] = copyins.size(); 275 } else if (keyword == "allocate") { 276 // fail if there was already another allocate clause 277 if (segments[allocateClausePos]) 278 return allowedOnce(parser, "allocate", opName); 279 if (parseAllocateAndAllocator(parser, allocates, allocateTypes, 280 allocators, allocatorTypes)) 281 return failure(); 282 segments[allocateClausePos] = allocates.size(); 283 segments[allocatorPos] = allocators.size(); 284 } else if (keyword == "default") { 285 // fail if there was already another default clause 286 if (defaultVal) 287 return allowedOnce(parser, "default", 
opName); 288 defaultVal = true; 289 llvm::StringRef defval; 290 if (parser.parseLParen() || parser.parseKeyword(&defval) || 291 parser.parseRParen()) 292 return failure(); 293 llvm::SmallString<16> attrval; 294 // The def prefix is required for the attribute as "private" is a keyword 295 // in C++ 296 attrval += "def"; 297 attrval += defval; 298 auto attr = parser.getBuilder().getStringAttr(attrval); 299 result.addAttribute("default_val", attr); 300 } else if (keyword == "proc_bind") { 301 // fail if there was already another proc_bind clause 302 if (procBind) 303 return allowedOnce(parser, "proc_bind", opName); 304 procBind = true; 305 llvm::StringRef bind; 306 if (parser.parseLParen() || parser.parseKeyword(&bind) || 307 parser.parseRParen()) 308 return failure(); 309 auto attr = parser.getBuilder().getStringAttr(bind); 310 result.addAttribute("proc_bind_val", attr); 311 } else { 312 return parser.emitError(parser.getNameLoc()) 313 << keyword << " is not a valid clause for the " << opName 314 << " operation"; 315 } 316 } 317 318 // Add if parameter 319 if (segments[ifClausePos] && 320 parser.resolveOperand(ifCond.first, ifCond.second, result.operands)) 321 return failure(); 322 323 // Add num_threads parameter 324 if (segments[numThreadsClausePos] && 325 parser.resolveOperand(numThreads.first, numThreads.second, 326 result.operands)) 327 return failure(); 328 329 // Add private parameters 330 if (segments[privateClausePos] && 331 parser.resolveOperands(privates, privateTypes, privates[0].location, 332 result.operands)) 333 return failure(); 334 335 // Add firstprivate parameters 336 if (segments[firstprivateClausePos] && 337 parser.resolveOperands(firstprivates, firstprivateTypes, 338 firstprivates[0].location, result.operands)) 339 return failure(); 340 341 // Add shared parameters 342 if (segments[sharedClausePos] && 343 parser.resolveOperands(shareds, sharedTypes, shareds[0].location, 344 result.operands)) 345 return failure(); 346 347 // Add copyin parameters 
348 if (segments[copyinClausePos] && 349 parser.resolveOperands(copyins, copyinTypes, copyins[0].location, 350 result.operands)) 351 return failure(); 352 353 // Add allocate parameters 354 if (segments[allocateClausePos] && 355 parser.resolveOperands(allocates, allocateTypes, allocates[0].location, 356 result.operands)) 357 return failure(); 358 359 // Add allocator parameters 360 if (segments[allocatorPos] && 361 parser.resolveOperands(allocators, allocatorTypes, allocators[0].location, 362 result.operands)) 363 return failure(); 364 365 result.addAttribute("operand_segment_sizes", 366 parser.getBuilder().getI32VectorAttr(segments)); 367 368 Region *body = result.addRegion(); 369 SmallVector<OpAsmParser::OperandType, 4> regionArgs; 370 SmallVector<Type, 4> regionArgTypes; 371 if (parser.parseRegion(*body, regionArgs, regionArgTypes)) 372 return failure(); 373 return success(); 374 } 375 376 //===----------------------------------------------------------------------===// 377 // WsLoopOp 378 //===----------------------------------------------------------------------===// 379 380 void WsLoopOp::build(OpBuilder &builder, OperationState &state, 381 ValueRange lowerBound, ValueRange upperBound, 382 ValueRange step, ArrayRef<NamedAttribute> attributes) { 383 build(builder, state, TypeRange(), lowerBound, upperBound, step, 384 /*private_vars=*/ValueRange(), 385 /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(), 386 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), 387 /*schedule_val=*/nullptr, /*schedule_chunk_var=*/nullptr, 388 /*collapse_val=*/nullptr, 389 /*nowait=*/false, /*ordered_val=*/nullptr, /*order_val=*/nullptr, 390 /*inclusive=*/false); 391 state.addAttributes(attributes); 392 } 393 394 #define GET_OP_CLASSES 395 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 396