//===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the OpenMP dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"

#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include <cstddef>

#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"

using namespace mlir;
using namespace mlir::omp;

/// Constructs the OpenMP dialect in `context` and registers its operations.
OpenMPDialect::OpenMPDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  // The generated .inc file, included with GET_OP_LIST defined, supplies the
  // template argument list (the operation classes) for addOperations.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// ParallelOp
//===----------------------------------------------------------------------===//

/// Parse a list of operands with types.
42 /// 43 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)` 44 /// ssa-id-and-type-list ::= ssa-id-and-type | 45 /// ssa-id-and-type ',' ssa-id-and-type-list 46 /// ssa-id-and-type ::= ssa-id `:` type 47 static ParseResult 48 parseOperandAndTypeList(OpAsmParser &parser, 49 SmallVectorImpl<OpAsmParser::OperandType> &operands, 50 SmallVectorImpl<Type> &types) { 51 if (parser.parseLParen()) 52 return failure(); 53 54 do { 55 OpAsmParser::OperandType operand; 56 Type type; 57 if (parser.parseOperand(operand) || parser.parseColonType(type)) 58 return failure(); 59 operands.push_back(operand); 60 types.push_back(type); 61 } while (succeeded(parser.parseOptionalComma())); 62 63 if (parser.parseRParen()) 64 return failure(); 65 66 return success(); 67 } 68 69 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) { 70 p << "omp.parallel"; 71 72 if (auto ifCond = op.if_expr_var()) 73 p << " if(" << ifCond << ")"; 74 75 if (auto threads = op.num_threads_var()) 76 p << " num_threads(" << threads << " : " << threads.getType() << ")"; 77 78 // Print private, firstprivate, shared and copyin parameters 79 auto printDataVars = [&p](StringRef name, OperandRange vars) { 80 if (vars.size()) { 81 p << " " << name << "("; 82 for (unsigned i = 0; i < vars.size(); ++i) { 83 std::string separator = i == vars.size() - 1 ? ")" : ", "; 84 p << vars[i] << " : " << vars[i].getType() << separator; 85 } 86 } 87 }; 88 printDataVars("private", op.private_vars()); 89 printDataVars("firstprivate", op.firstprivate_vars()); 90 printDataVars("shared", op.shared_vars()); 91 printDataVars("copyin", op.copyin_vars()); 92 93 if (auto def = op.default_val()) 94 p << " default(" << def->drop_front(3) << ")"; 95 96 if (auto bind = op.proc_bind_val()) 97 p << " proc_bind(" << bind << ")"; 98 99 p.printRegion(op.getRegion()); 100 } 101 102 /// Emit an error if the same clause is present more than once on an operation. 
103 static ParseResult allowedOnce(OpAsmParser &parser, llvm::StringRef clause, 104 llvm::StringRef operation) { 105 return parser.emitError(parser.getNameLoc()) 106 << " at most one " << clause << " clause can appear on the " 107 << operation << " operation"; 108 } 109 110 /// Parses a parallel operation. 111 /// 112 /// operation ::= `omp.parallel` clause-list 113 /// clause-list ::= clause | clause clause-list 114 /// clause ::= if | numThreads | private | firstprivate | shared | copyin | 115 /// default | procBind 116 /// if ::= `if` `(` ssa-id `)` 117 /// numThreads ::= `num_threads` `(` ssa-id-and-type `)` 118 /// private ::= `private` operand-and-type-list 119 /// firstprivate ::= `firstprivate` operand-and-type-list 120 /// shared ::= `shared` operand-and-type-list 121 /// copyin ::= `copyin` operand-and-type-list 122 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) 123 /// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)` 124 /// 125 /// Note that each clause can only appear once in the clase-list. 
126 static ParseResult parseParallelOp(OpAsmParser &parser, 127 OperationState &result) { 128 OpAsmParser::OperandType ifCond; 129 std::pair<OpAsmParser::OperandType, Type> numThreads; 130 llvm::SmallVector<OpAsmParser::OperandType, 4> privates; 131 llvm::SmallVector<Type, 4> privateTypes; 132 llvm::SmallVector<OpAsmParser::OperandType, 4> firstprivates; 133 llvm::SmallVector<Type, 4> firstprivateTypes; 134 llvm::SmallVector<OpAsmParser::OperandType, 4> shareds; 135 llvm::SmallVector<Type, 4> sharedTypes; 136 llvm::SmallVector<OpAsmParser::OperandType, 4> copyins; 137 llvm::SmallVector<Type, 4> copyinTypes; 138 std::array<int, 6> segments{0, 0, 0, 0, 0, 0}; 139 llvm::StringRef keyword; 140 bool defaultVal = false; 141 bool procBind = false; 142 143 const int ifClausePos = 0; 144 const int numThreadsClausePos = 1; 145 const int privateClausePos = 2; 146 const int firstprivateClausePos = 3; 147 const int sharedClausePos = 4; 148 const int copyinClausePos = 5; 149 const llvm::StringRef opName = result.name.getStringRef(); 150 151 while (succeeded(parser.parseOptionalKeyword(&keyword))) { 152 if (keyword == "if") { 153 // Fail if there was already another if condition 154 if (segments[ifClausePos]) 155 return allowedOnce(parser, "if", opName); 156 if (parser.parseLParen() || parser.parseOperand(ifCond) || 157 parser.parseRParen()) 158 return failure(); 159 segments[ifClausePos] = 1; 160 } else if (keyword == "num_threads") { 161 // fail if there was already another num_threads clause 162 if (segments[numThreadsClausePos]) 163 return allowedOnce(parser, "num_threads", opName); 164 if (parser.parseLParen() || parser.parseOperand(numThreads.first) || 165 parser.parseColonType(numThreads.second) || parser.parseRParen()) 166 return failure(); 167 segments[numThreadsClausePos] = 1; 168 } else if (keyword == "private") { 169 // fail if there was already another private clause 170 if (segments[privateClausePos]) 171 return allowedOnce(parser, "private", opName); 172 if 
(parseOperandAndTypeList(parser, privates, privateTypes)) 173 return failure(); 174 segments[privateClausePos] = privates.size(); 175 } else if (keyword == "firstprivate") { 176 // fail if there was already another firstprivate clause 177 if (segments[firstprivateClausePos]) 178 return allowedOnce(parser, "firstprivate", opName); 179 if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes)) 180 return failure(); 181 segments[firstprivateClausePos] = firstprivates.size(); 182 } else if (keyword == "shared") { 183 // fail if there was already another shared clause 184 if (segments[sharedClausePos]) 185 return allowedOnce(parser, "shared", opName); 186 if (parseOperandAndTypeList(parser, shareds, sharedTypes)) 187 return failure(); 188 segments[sharedClausePos] = shareds.size(); 189 } else if (keyword == "copyin") { 190 // fail if there was already another copyin clause 191 if (segments[copyinClausePos]) 192 return allowedOnce(parser, "copyin", opName); 193 if (parseOperandAndTypeList(parser, copyins, copyinTypes)) 194 return failure(); 195 segments[copyinClausePos] = copyins.size(); 196 } else if (keyword == "default") { 197 // fail if there was already another default clause 198 if (defaultVal) 199 return allowedOnce(parser, "default", opName); 200 defaultVal = true; 201 llvm::StringRef defval; 202 if (parser.parseLParen() || parser.parseKeyword(&defval) || 203 parser.parseRParen()) 204 return failure(); 205 llvm::SmallString<16> attrval; 206 // The def prefix is required for the attribute as "private" is a keyword 207 // in C++ 208 attrval += "def"; 209 attrval += defval; 210 auto attr = parser.getBuilder().getStringAttr(attrval); 211 result.addAttribute("default_val", attr); 212 } else if (keyword == "proc_bind") { 213 // fail if there was already another default clause 214 if (procBind) 215 return allowedOnce(parser, "proc_bind", opName); 216 procBind = true; 217 llvm::StringRef bind; 218 if (parser.parseLParen() || parser.parseKeyword(&bind) || 219 
parser.parseRParen()) 220 return failure(); 221 auto attr = parser.getBuilder().getStringAttr(bind); 222 result.addAttribute("proc_bind_val", attr); 223 } else { 224 return parser.emitError(parser.getNameLoc()) 225 << keyword << " is not a valid clause for the " << opName 226 << " operation"; 227 } 228 } 229 230 // Add if parameter 231 if (segments[ifClausePos]) { 232 parser.resolveOperand(ifCond, parser.getBuilder().getI1Type(), 233 result.operands); 234 } 235 236 // Add num_threads parameter 237 if (segments[numThreadsClausePos]) { 238 parser.resolveOperand(numThreads.first, numThreads.second, result.operands); 239 } 240 241 // Add private parameters 242 if (segments[privateClausePos]) { 243 parser.resolveOperands(privates, privateTypes, privates[0].location, 244 result.operands); 245 } 246 247 // Add firstprivate parameters 248 if (segments[firstprivateClausePos]) { 249 parser.resolveOperands(firstprivates, firstprivateTypes, 250 firstprivates[0].location, result.operands); 251 } 252 253 // Add shared parameters 254 if (segments[sharedClausePos]) { 255 parser.resolveOperands(shareds, sharedTypes, shareds[0].location, 256 result.operands); 257 } 258 259 // Add copyin parameters 260 if (segments[copyinClausePos]) { 261 parser.resolveOperands(copyins, copyinTypes, copyins[0].location, 262 result.operands); 263 } 264 265 result.addAttribute("operand_segment_sizes", 266 parser.getBuilder().getI32VectorAttr(segments)); 267 268 Region *body = result.addRegion(); 269 llvm::SmallVector<OpAsmParser::OperandType, 4> regionArgs; 270 llvm::SmallVector<Type, 4> regionArgTypes; 271 if (parser.parseRegion(*body, regionArgs, regionArgTypes)) 272 return failure(); 273 return success(); 274 } 275 276 namespace mlir { 277 namespace omp { 278 #define GET_OP_CLASSES 279 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" 280 } // namespace omp 281 } // namespace mlir 282