//===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the OpenMP dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"

#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include <cstddef>

// Generated enum definitions for the OpenMP dialect. This is included before
// the `using` directives below, so the generated code must carry its own
// namespace qualification.
#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"

using namespace mlir;
using namespace mlir::omp;

/// Registers every operation of the OpenMP dialect with this dialect instance.
///
/// With GET_OP_LIST defined, the generated OpenMPOps.cpp.inc expands to the
/// list of op class names, which becomes the template argument list of
/// addOperations<...>().
void OpenMPDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// ParallelOp
//===----------------------------------------------------------------------===//

/// Parse a list of operands with types.
41 /// 42 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)` 43 /// ssa-id-and-type-list ::= ssa-id-and-type | 44 /// ssa-id-and-type ',' ssa-id-and-type-list 45 /// ssa-id-and-type ::= ssa-id `:` type 46 static ParseResult 47 parseOperandAndTypeList(OpAsmParser &parser, 48 SmallVectorImpl<OpAsmParser::OperandType> &operands, 49 SmallVectorImpl<Type> &types) { 50 if (parser.parseLParen()) 51 return failure(); 52 53 do { 54 OpAsmParser::OperandType operand; 55 Type type; 56 if (parser.parseOperand(operand) || parser.parseColonType(type)) 57 return failure(); 58 operands.push_back(operand); 59 types.push_back(type); 60 } while (succeeded(parser.parseOptionalComma())); 61 62 if (parser.parseRParen()) 63 return failure(); 64 65 return success(); 66 } 67 68 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) { 69 p << "omp.parallel"; 70 71 if (auto ifCond = op.if_expr_var()) 72 p << " if(" << ifCond << " : " << ifCond.getType() << ")"; 73 74 if (auto threads = op.num_threads_var()) 75 p << " num_threads(" << threads << " : " << threads.getType() << ")"; 76 77 // Print private, firstprivate, shared and copyin parameters 78 auto printDataVars = [&p](StringRef name, OperandRange vars) { 79 if (vars.size()) { 80 p << " " << name << "("; 81 for (unsigned i = 0; i < vars.size(); ++i) { 82 std::string separator = i == vars.size() - 1 ? ")" : ", "; 83 p << vars[i] << " : " << vars[i].getType() << separator; 84 } 85 } 86 }; 87 printDataVars("private", op.private_vars()); 88 printDataVars("firstprivate", op.firstprivate_vars()); 89 printDataVars("shared", op.shared_vars()); 90 printDataVars("copyin", op.copyin_vars()); 91 92 if (auto def = op.default_val()) 93 p << " default(" << def->drop_front(3) << ")"; 94 95 if (auto bind = op.proc_bind_val()) 96 p << " proc_bind(" << bind << ")"; 97 98 p.printRegion(op.getRegion()); 99 } 100 101 /// Emit an error if the same clause is present more than once on an operation. 
102 static ParseResult allowedOnce(OpAsmParser &parser, llvm::StringRef clause, 103 llvm::StringRef operation) { 104 return parser.emitError(parser.getNameLoc()) 105 << " at most one " << clause << " clause can appear on the " 106 << operation << " operation"; 107 } 108 109 /// Parses a parallel operation. 110 /// 111 /// operation ::= `omp.parallel` clause-list 112 /// clause-list ::= clause | clause clause-list 113 /// clause ::= if | numThreads | private | firstprivate | shared | copyin | 114 /// default | procBind 115 /// if ::= `if` `(` ssa-id `)` 116 /// numThreads ::= `num_threads` `(` ssa-id-and-type `)` 117 /// private ::= `private` operand-and-type-list 118 /// firstprivate ::= `firstprivate` operand-and-type-list 119 /// shared ::= `shared` operand-and-type-list 120 /// copyin ::= `copyin` operand-and-type-list 121 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) 122 /// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` 123 /// 124 /// Note that each clause can only appear once in the clase-list. 
125 static ParseResult parseParallelOp(OpAsmParser &parser, 126 OperationState &result) { 127 std::pair<OpAsmParser::OperandType, Type> ifCond; 128 std::pair<OpAsmParser::OperandType, Type> numThreads; 129 llvm::SmallVector<OpAsmParser::OperandType, 4> privates; 130 llvm::SmallVector<Type, 4> privateTypes; 131 llvm::SmallVector<OpAsmParser::OperandType, 4> firstprivates; 132 llvm::SmallVector<Type, 4> firstprivateTypes; 133 llvm::SmallVector<OpAsmParser::OperandType, 4> shareds; 134 llvm::SmallVector<Type, 4> sharedTypes; 135 llvm::SmallVector<OpAsmParser::OperandType, 4> copyins; 136 llvm::SmallVector<Type, 4> copyinTypes; 137 std::array<int, 6> segments{0, 0, 0, 0, 0, 0}; 138 llvm::StringRef keyword; 139 bool defaultVal = false; 140 bool procBind = false; 141 142 const int ifClausePos = 0; 143 const int numThreadsClausePos = 1; 144 const int privateClausePos = 2; 145 const int firstprivateClausePos = 3; 146 const int sharedClausePos = 4; 147 const int copyinClausePos = 5; 148 const llvm::StringRef opName = result.name.getStringRef(); 149 150 while (succeeded(parser.parseOptionalKeyword(&keyword))) { 151 if (keyword == "if") { 152 // Fail if there was already another if condition 153 if (segments[ifClausePos]) 154 return allowedOnce(parser, "if", opName); 155 if (parser.parseLParen() || parser.parseOperand(ifCond.first) || 156 parser.parseColonType(ifCond.second) || parser.parseRParen()) 157 return failure(); 158 segments[ifClausePos] = 1; 159 } else if (keyword == "num_threads") { 160 // fail if there was already another num_threads clause 161 if (segments[numThreadsClausePos]) 162 return allowedOnce(parser, "num_threads", opName); 163 if (parser.parseLParen() || parser.parseOperand(numThreads.first) || 164 parser.parseColonType(numThreads.second) || parser.parseRParen()) 165 return failure(); 166 segments[numThreadsClausePos] = 1; 167 } else if (keyword == "private") { 168 // fail if there was already another private clause 169 if (segments[privateClausePos]) 
170 return allowedOnce(parser, "private", opName); 171 if (parseOperandAndTypeList(parser, privates, privateTypes)) 172 return failure(); 173 segments[privateClausePos] = privates.size(); 174 } else if (keyword == "firstprivate") { 175 // fail if there was already another firstprivate clause 176 if (segments[firstprivateClausePos]) 177 return allowedOnce(parser, "firstprivate", opName); 178 if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) 179 return failure(); 180 segments[firstprivateClausePos] = firstprivates.size(); 181 } else if (keyword == "shared") { 182 // fail if there was already another shared clause 183 if (segments[sharedClausePos]) 184 return allowedOnce(parser, "shared", opName); 185 if (parseOperandAndTypeList(parser, shareds, sharedTypes)) 186 return failure(); 187 segments[sharedClausePos] = shareds.size(); 188 } else if (keyword == "copyin") { 189 // fail if there was already another copyin clause 190 if (segments[copyinClausePos]) 191 return allowedOnce(parser, "copyin", opName); 192 if (parseOperandAndTypeList(parser, copyins, copyinTypes)) 193 return failure(); 194 segments[copyinClausePos] = copyins.size(); 195 } else if (keyword == "default") { 196 // fail if there was already another default clause 197 if (defaultVal) 198 return allowedOnce(parser, "default", opName); 199 defaultVal = true; 200 llvm::StringRef defval; 201 if (parser.parseLParen() || parser.parseKeyword(&defval) || 202 parser.parseRParen()) 203 return failure(); 204 llvm::SmallString<16> attrval; 205 // The def prefix is required for the attribute as "private" is a keyword 206 // in C++ 207 attrval += "def"; 208 attrval += defval; 209 auto attr = parser.getBuilder().getStringAttr(attrval); 210 result.addAttribute("default_val", attr); 211 } else if (keyword == "proc_bind") { 212 // fail if there was already another proc_bind clause 213 if (procBind) 214 return allowedOnce(parser, "proc_bind", opName); 215 procBind = true; 216 llvm::StringRef bind; 217 if 
(parser.parseLParen() || parser.parseKeyword(&bind) || 218 parser.parseRParen()) 219 return failure(); 220 auto attr = parser.getBuilder().getStringAttr(bind); 221 result.addAttribute("proc_bind_val", attr); 222 } else { 223 return parser.emitError(parser.getNameLoc()) 224 << keyword << " is not a valid clause for the " << opName 225 << " operation"; 226 } 227 } 228 229 // Add if parameter 230 if (segments[ifClausePos]) { 231 parser.resolveOperand(ifCond.first, ifCond.second, result.operands); 232 } 233 234 // Add num_threads parameter 235 if (segments[numThreadsClausePos]) { 236 parser.resolveOperand(numThreads.first, numThreads.second, result.operands); 237 } 238 239 // Add private parameters 240 if (segments[privateClausePos]) { 241 parser.resolveOperands(privates, privateTypes, privates[0].location, 242 result.operands); 243 } 244 245 // Add firstprivate parameters 246 if (segments[firstprivateClausePos]) { 247 parser.resolveOperands(firstprivates, firstprivateTypes, 248 firstprivates[0].location, result.operands); 249 } 250 251 // Add shared parameters 252 if (segments[sharedClausePos]) { 253 parser.resolveOperands(shareds, sharedTypes, shareds[0].location, 254 result.operands); 255 } 256 257 // Add copyin parameters 258 if (segments[copyinClausePos]) { 259 parser.resolveOperands(copyins, copyinTypes, copyins[0].location, 260 result.operands); 261 } 262 263 result.addAttribute("operand_segment_sizes", 264 parser.getBuilder().getI32VectorAttr(segments)); 265 266 Region *body = result.addRegion(); 267 llvm::SmallVector<OpAsmParser::OperandType, 4> regionArgs; 268 llvm::SmallVector<Type, 4> regionArgTypes; 269 if (parser.parseRegion(*body, regionArgs, regionArgTypes)) 270 return failure(); 271 return success(); 272 } 273 274 namespace mlir { 275 namespace omp { 276 #define GET_OP_CLASSES 277 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 278 } // namespace omp 279 } // namespace mlir 280