//===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the OpenMP dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"

#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include <cstddef>

#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"

using namespace mlir;
using namespace mlir::omp;

/// Adds the dialect's operations to this dialect instance. The list of
/// operation classes passed to `addOperations` is not spelled out here: it is
/// the `GET_OP_LIST` section of the included OpenMPOps.cpp.inc file.
void OpenMPDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// ParallelOp
//===----------------------------------------------------------------------===//

/// Parse a list of operands with types.
41 /// 42 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)` 43 /// ssa-id-and-type-list ::= ssa-id-and-type | 44 /// ssa-id-and-type `,` ssa-id-and-type-list 45 /// ssa-id-and-type ::= ssa-id `:` type 46 static ParseResult 47 parseOperandAndTypeList(OpAsmParser &parser, 48 SmallVectorImpl<OpAsmParser::OperandType> &operands, 49 SmallVectorImpl<Type> &types) { 50 if (parser.parseLParen()) 51 return failure(); 52 53 do { 54 OpAsmParser::OperandType operand; 55 Type type; 56 if (parser.parseOperand(operand) || parser.parseColonType(type)) 57 return failure(); 58 operands.push_back(operand); 59 types.push_back(type); 60 } while (succeeded(parser.parseOptionalComma())); 61 62 if (parser.parseRParen()) 63 return failure(); 64 65 return success(); 66 } 67 68 /// Parse an allocate clause with allocators and a list of operands with types. 69 /// 70 /// operand-and-type-list ::= `(` allocate-operand-list `)` 71 /// allocate-operand-list :: = allocate-operand | 72 /// allocator-operand `,` allocate-operand-list 73 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type 74 /// ssa-id-and-type ::= ssa-id `:` type 75 static ParseResult parseAllocateAndAllocator( 76 OpAsmParser &parser, 77 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate, 78 SmallVectorImpl<Type> &typesAllocate, 79 SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator, 80 SmallVectorImpl<Type> &typesAllocator) { 81 if (parser.parseLParen()) 82 return failure(); 83 84 do { 85 OpAsmParser::OperandType operand; 86 Type type; 87 88 if (parser.parseOperand(operand) || parser.parseColonType(type)) 89 return failure(); 90 operandsAllocator.push_back(operand); 91 typesAllocator.push_back(type); 92 if (parser.parseArrow()) 93 return failure(); 94 if (parser.parseOperand(operand) || parser.parseColonType(type)) 95 return failure(); 96 97 operandsAllocate.push_back(operand); 98 typesAllocate.push_back(type); 99 } while (succeeded(parser.parseOptionalComma())); 100 101 if 
(parser.parseRParen()) 102 return failure(); 103 104 return success(); 105 } 106 107 static LogicalResult verifyParallelOp(ParallelOp op) { 108 if (op.allocate_vars().size() != op.allocators_vars().size()) 109 return op.emitError( 110 "expected equal sizes for allocate and allocator variables"); 111 return success(); 112 } 113 114 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) { 115 p << "omp.parallel"; 116 117 if (auto ifCond = op.if_expr_var()) 118 p << " if(" << ifCond << " : " << ifCond.getType() << ")"; 119 120 if (auto threads = op.num_threads_var()) 121 p << " num_threads(" << threads << " : " << threads.getType() << ")"; 122 123 // Print private, firstprivate, shared and copyin parameters 124 auto printDataVars = [&p](StringRef name, OperandRange vars) { 125 if (vars.size()) { 126 p << " " << name << "("; 127 for (unsigned i = 0; i < vars.size(); ++i) { 128 std::string separator = i == vars.size() - 1 ? ")" : ", "; 129 p << vars[i] << " : " << vars[i].getType() << separator; 130 } 131 } 132 }; 133 134 // Print allocator and allocate parameters 135 auto printAllocateAndAllocator = [&p](OperandRange varsAllocate, 136 OperandRange varsAllocator) { 137 if (varsAllocate.empty()) 138 return; 139 140 p << " allocate("; 141 for (unsigned i = 0; i < varsAllocate.size(); ++i) { 142 std::string separator = i == varsAllocate.size() - 1 ? 
")" : ", "; 143 p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> "; 144 p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator; 145 } 146 }; 147 148 printDataVars("private", op.private_vars()); 149 printDataVars("firstprivate", op.firstprivate_vars()); 150 printDataVars("shared", op.shared_vars()); 151 printDataVars("copyin", op.copyin_vars()); 152 printAllocateAndAllocator(op.allocate_vars(), op.allocators_vars()); 153 154 if (auto def = op.default_val()) 155 p << " default(" << def->drop_front(3) << ")"; 156 157 if (auto bind = op.proc_bind_val()) 158 p << " proc_bind(" << bind << ")"; 159 160 p.printRegion(op.getRegion()); 161 } 162 163 /// Emit an error if the same clause is present more than once on an operation. 164 static ParseResult allowedOnce(OpAsmParser &parser, llvm::StringRef clause, 165 llvm::StringRef operation) { 166 return parser.emitError(parser.getNameLoc()) 167 << " at most one " << clause << " clause can appear on the " 168 << operation << " operation"; 169 } 170 171 /// Parses a parallel operation. 172 /// 173 /// operation ::= `omp.parallel` clause-list 174 /// clause-list ::= clause | clause clause-list 175 /// clause ::= if | numThreads | private | firstprivate | shared | copyin | 176 /// default | procBind 177 /// if ::= `if` `(` ssa-id `)` 178 /// numThreads ::= `num_threads` `(` ssa-id-and-type `)` 179 /// private ::= `private` operand-and-type-list 180 /// firstprivate ::= `firstprivate` operand-and-type-list 181 /// shared ::= `shared` operand-and-type-list 182 /// copyin ::= `copyin` operand-and-type-list 183 /// allocate ::= `allocate` operand-and-type `->` operand-and-type-list 184 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) 185 /// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` 186 /// 187 /// Note that each clause can only appear once in the clase-list. 
188 static ParseResult parseParallelOp(OpAsmParser &parser, 189 OperationState &result) { 190 std::pair<OpAsmParser::OperandType, Type> ifCond; 191 std::pair<OpAsmParser::OperandType, Type> numThreads; 192 SmallVector<OpAsmParser::OperandType, 4> privates; 193 SmallVector<Type, 4> privateTypes; 194 SmallVector<OpAsmParser::OperandType, 4> firstprivates; 195 SmallVector<Type, 4> firstprivateTypes; 196 SmallVector<OpAsmParser::OperandType, 4> shareds; 197 SmallVector<Type, 4> sharedTypes; 198 SmallVector<OpAsmParser::OperandType, 4> copyins; 199 SmallVector<Type, 4> copyinTypes; 200 SmallVector<OpAsmParser::OperandType, 4> allocates; 201 SmallVector<Type, 4> allocateTypes; 202 SmallVector<OpAsmParser::OperandType, 4> allocators; 203 SmallVector<Type, 4> allocatorTypes; 204 std::array<int, 8> segments{0, 0, 0, 0, 0, 0, 0, 0}; 205 llvm::StringRef keyword; 206 bool defaultVal = false; 207 bool procBind = false; 208 209 const int ifClausePos = 0; 210 const int numThreadsClausePos = 1; 211 const int privateClausePos = 2; 212 const int firstprivateClausePos = 3; 213 const int sharedClausePos = 4; 214 const int copyinClausePos = 5; 215 const int allocateClausePos = 6; 216 const int allocatorPos = 7; 217 const llvm::StringRef opName = result.name.getStringRef(); 218 219 while (succeeded(parser.parseOptionalKeyword(&keyword))) { 220 if (keyword == "if") { 221 // Fail if there was already another if condition 222 if (segments[ifClausePos]) 223 return allowedOnce(parser, "if", opName); 224 if (parser.parseLParen() || parser.parseOperand(ifCond.first) || 225 parser.parseColonType(ifCond.second) || parser.parseRParen()) 226 return failure(); 227 segments[ifClausePos] = 1; 228 } else if (keyword == "num_threads") { 229 // fail if there was already another num_threads clause 230 if (segments[numThreadsClausePos]) 231 return allowedOnce(parser, "num_threads", opName); 232 if (parser.parseLParen() || parser.parseOperand(numThreads.first) || 233 
parser.parseColonType(numThreads.second) || parser.parseRParen()) 234 return failure(); 235 segments[numThreadsClausePos] = 1; 236 } else if (keyword == "private") { 237 // fail if there was already another private clause 238 if (segments[privateClausePos]) 239 return allowedOnce(parser, "private", opName); 240 if (parseOperandAndTypeList(parser, privates, privateTypes)) 241 return failure(); 242 segments[privateClausePos] = privates.size(); 243 } else if (keyword == "firstprivate") { 244 // fail if there was already another firstprivate clause 245 if (segments[firstprivateClausePos]) 246 return allowedOnce(parser, "firstprivate", opName); 247 if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) 248 return failure(); 249 segments[firstprivateClausePos] = firstprivates.size(); 250 } else if (keyword == "shared") { 251 // fail if there was already another shared clause 252 if (segments[sharedClausePos]) 253 return allowedOnce(parser, "shared", opName); 254 if (parseOperandAndTypeList(parser, shareds, sharedTypes)) 255 return failure(); 256 segments[sharedClausePos] = shareds.size(); 257 } else if (keyword == "copyin") { 258 // fail if there was already another copyin clause 259 if (segments[copyinClausePos]) 260 return allowedOnce(parser, "copyin", opName); 261 if (parseOperandAndTypeList(parser, copyins, copyinTypes)) 262 return failure(); 263 segments[copyinClausePos] = copyins.size(); 264 } else if (keyword == "allocate") { 265 // fail if there was already another allocate clause 266 if (segments[allocateClausePos]) 267 return allowedOnce(parser, "allocate", opName); 268 if (parseAllocateAndAllocator(parser, allocates, allocateTypes, 269 allocators, allocatorTypes)) 270 return failure(); 271 segments[allocateClausePos] = allocates.size(); 272 segments[allocatorPos] = allocators.size(); 273 } else if (keyword == "default") { 274 // fail if there was already another default clause 275 if (defaultVal) 276 return allowedOnce(parser, "default", 
opName); 277 defaultVal = true; 278 llvm::StringRef defval; 279 if (parser.parseLParen() || parser.parseKeyword(&defval) || 280 parser.parseRParen()) 281 return failure(); 282 llvm::SmallString<16> attrval; 283 // The def prefix is required for the attribute as "private" is a keyword 284 // in C++ 285 attrval += "def"; 286 attrval += defval; 287 auto attr = parser.getBuilder().getStringAttr(attrval); 288 result.addAttribute("default_val", attr); 289 } else if (keyword == "proc_bind") { 290 // fail if there was already another proc_bind clause 291 if (procBind) 292 return allowedOnce(parser, "proc_bind", opName); 293 procBind = true; 294 llvm::StringRef bind; 295 if (parser.parseLParen() || parser.parseKeyword(&bind) || 296 parser.parseRParen()) 297 return failure(); 298 auto attr = parser.getBuilder().getStringAttr(bind); 299 result.addAttribute("proc_bind_val", attr); 300 } else { 301 return parser.emitError(parser.getNameLoc()) 302 << keyword << " is not a valid clause for the " << opName 303 << " operation"; 304 } 305 } 306 307 // Add if parameter 308 if (segments[ifClausePos] && 309 parser.resolveOperand(ifCond.first, ifCond.second, result.operands)) 310 return failure(); 311 312 // Add num_threads parameter 313 if (segments[numThreadsClausePos] && 314 parser.resolveOperand(numThreads.first, numThreads.second, 315 result.operands)) 316 return failure(); 317 318 // Add private parameters 319 if (segments[privateClausePos] && 320 parser.resolveOperands(privates, privateTypes, privates[0].location, 321 result.operands)) 322 return failure(); 323 324 // Add firstprivate parameters 325 if (segments[firstprivateClausePos] && 326 parser.resolveOperands(firstprivates, firstprivateTypes, 327 firstprivates[0].location, result.operands)) 328 return failure(); 329 330 // Add shared parameters 331 if (segments[sharedClausePos] && 332 parser.resolveOperands(shareds, sharedTypes, shareds[0].location, 333 result.operands)) 334 return failure(); 335 336 // Add copyin parameters 
337 if (segments[copyinClausePos] && 338 parser.resolveOperands(copyins, copyinTypes, copyins[0].location, 339 result.operands)) 340 return failure(); 341 342 // Add allocate parameters 343 if (segments[allocateClausePos] && 344 parser.resolveOperands(allocates, allocateTypes, allocates[0].location, 345 result.operands)) 346 return failure(); 347 348 // Add allocator parameters 349 if (segments[allocatorPos] && 350 parser.resolveOperands(allocators, allocatorTypes, allocators[0].location, 351 result.operands)) 352 return failure(); 353 354 result.addAttribute("operand_segment_sizes", 355 parser.getBuilder().getI32VectorAttr(segments)); 356 357 Region *body = result.addRegion(); 358 SmallVector<OpAsmParser::OperandType, 4> regionArgs; 359 SmallVector<Type, 4> regionArgTypes; 360 if (parser.parseRegion(*body, regionArgs, regionArgTypes)) 361 return failure(); 362 return success(); 363 } 364 365 #define GET_OP_CLASSES 366 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 367