1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"

#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include <array>
#include <cstddef>
23 
24 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
25 
26 using namespace mlir;
27 using namespace mlir::omp;
28 
/// Constructs the OpenMP dialect under its namespace and registers all of its
/// operations with `context`. The operation type list is spliced into the
/// `addOperations<...>` template argument list by including OpenMPOps.cpp.inc
/// with GET_OP_LIST defined, so the #define/#include pair below must stay
/// exactly where it is.
OpenMPDialect::OpenMPDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();
}
36 
37 //===----------------------------------------------------------------------===//
38 // ParallelOp
39 //===----------------------------------------------------------------------===//
40 
41 /// Parse a list of operands with types.
42 ///
43 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
44 /// ssa-id-and-type-list ::= ssa-id-and-type |
45 ///                          ssa-id-and-type ',' ssa-id-and-type-list
46 /// ssa-id-and-type ::= ssa-id `:` type
47 static ParseResult
48 parseOperandAndTypeList(OpAsmParser &parser,
49                         SmallVectorImpl<OpAsmParser::OperandType> &operands,
50                         SmallVectorImpl<Type> &types) {
51   if (parser.parseLParen())
52     return failure();
53 
54   do {
55     OpAsmParser::OperandType operand;
56     Type type;
57     if (parser.parseOperand(operand) || parser.parseColonType(type))
58       return failure();
59     operands.push_back(operand);
60     types.push_back(type);
61   } while (succeeded(parser.parseOptionalComma()));
62 
63   if (parser.parseRParen())
64     return failure();
65 
66   return success();
67 }
68 
69 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
70   p << "omp.parallel";
71 
72   if (auto ifCond = op.if_expr_var())
73     p << " if(" << ifCond << ")";
74 
75   if (auto threads = op.num_threads_var())
76     p << " num_threads(" << threads << " : " << threads.getType() << ")";
77 
78   // Print private, firstprivate, shared and copyin parameters
79   auto printDataVars = [&p](StringRef name, OperandRange vars) {
80     if (vars.size()) {
81       p << " " << name << "(";
82       for (unsigned i = 0; i < vars.size(); ++i) {
83         std::string separator = i == vars.size() - 1 ? ")" : ", ";
84         p << vars[i] << " : " << vars[i].getType() << separator;
85       }
86     }
87   };
88   printDataVars("private", op.private_vars());
89   printDataVars("firstprivate", op.firstprivate_vars());
90   printDataVars("shared", op.shared_vars());
91   printDataVars("copyin", op.copyin_vars());
92 
93   if (auto def = op.default_val())
94     p << " default(" << def->drop_front(3) << ")";
95 
96   if (auto bind = op.proc_bind_val())
97     p << " proc_bind(" << bind << ")";
98 
99   p.printRegion(op.getRegion());
100 }
101 
102 /// Emit an error if the same clause is present more than once on an operation.
103 static ParseResult allowedOnce(OpAsmParser &parser, llvm::StringRef clause,
104                                llvm::StringRef operation) {
105   return parser.emitError(parser.getNameLoc())
106          << " at most one " << clause << " clause can appear on the "
107          << operation << " operation";
108 }
109 
110 /// Parses a parallel operation.
111 ///
112 /// operation ::= `omp.parallel` clause-list
113 /// clause-list ::= clause | clause clause-list
114 /// clause ::= if | numThreads | private | firstprivate | shared | copyin |
115 ///            default | procBind
116 /// if ::= `if` `(` ssa-id `)`
117 /// numThreads ::= `num_threads` `(` ssa-id-and-type `)`
118 /// private ::= `private` operand-and-type-list
119 /// firstprivate ::= `firstprivate` operand-and-type-list
120 /// shared ::= `shared` operand-and-type-list
121 /// copyin ::= `copyin` operand-and-type-list
122 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`)
123 /// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
124 ///
125 /// Note that each clause can only appear once in the clase-list.
126 static ParseResult parseParallelOp(OpAsmParser &parser,
127                                    OperationState &result) {
128   OpAsmParser::OperandType ifCond;
129   std::pair<OpAsmParser::OperandType, Type> numThreads;
130   llvm::SmallVector<OpAsmParser::OperandType, 4> privates;
131   llvm::SmallVector<Type, 4> privateTypes;
132   llvm::SmallVector<OpAsmParser::OperandType, 4> firstprivates;
133   llvm::SmallVector<Type, 4> firstprivateTypes;
134   llvm::SmallVector<OpAsmParser::OperandType, 4> shareds;
135   llvm::SmallVector<Type, 4> sharedTypes;
136   llvm::SmallVector<OpAsmParser::OperandType, 4> copyins;
137   llvm::SmallVector<Type, 4> copyinTypes;
138   std::array<int, 6> segments{0, 0, 0, 0, 0, 0};
139   llvm::StringRef keyword;
140   bool defaultVal = false;
141   bool procBind = false;
142 
143   const int ifClausePos = 0;
144   const int numThreadsClausePos = 1;
145   const int privateClausePos = 2;
146   const int firstprivateClausePos = 3;
147   const int sharedClausePos = 4;
148   const int copyinClausePos = 5;
149   const llvm::StringRef opName = result.name.getStringRef();
150 
151   while (succeeded(parser.parseOptionalKeyword(&keyword))) {
152     if (keyword == "if") {
153       // Fail if there was already another if condition
154       if (segments[ifClausePos])
155         return allowedOnce(parser, "if", opName);
156       if (parser.parseLParen() || parser.parseOperand(ifCond) ||
157           parser.parseRParen())
158         return failure();
159       segments[ifClausePos] = 1;
160     } else if (keyword == "num_threads") {
161       // fail if there was already another num_threads clause
162       if (segments[numThreadsClausePos])
163         return allowedOnce(parser, "num_threads", opName);
164       if (parser.parseLParen() || parser.parseOperand(numThreads.first) ||
165           parser.parseColonType(numThreads.second) || parser.parseRParen())
166         return failure();
167       segments[numThreadsClausePos] = 1;
168     } else if (keyword == "private") {
169       // fail if there was already another private clause
170       if (segments[privateClausePos])
171         return allowedOnce(parser, "private", opName);
172       if (parseOperandAndTypeList(parser, privates, privateTypes))
173         return failure();
174       segments[privateClausePos] = privates.size();
175     } else if (keyword == "firstprivate") {
176       // fail if there was already another firstprivate clause
177       if (segments[firstprivateClausePos])
178         return allowedOnce(parser, "firstprivate", opName);
179       if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
180         return failure();
181       segments[firstprivateClausePos] = firstprivates.size();
182     } else if (keyword == "shared") {
183       // fail if there was already another shared clause
184       if (segments[sharedClausePos])
185         return allowedOnce(parser, "shared", opName);
186       if (parseOperandAndTypeList(parser, shareds, sharedTypes))
187         return failure();
188       segments[sharedClausePos] = shareds.size();
189     } else if (keyword == "copyin") {
190       // fail if there was already another copyin clause
191       if (segments[copyinClausePos])
192         return allowedOnce(parser, "copyin", opName);
193       if (parseOperandAndTypeList(parser, copyins, copyinTypes))
194         return failure();
195       segments[copyinClausePos] = copyins.size();
196     } else if (keyword == "default") {
197       // fail if there was already another default clause
198       if (defaultVal)
199         return allowedOnce(parser, "default", opName);
200       defaultVal = true;
201       llvm::StringRef defval;
202       if (parser.parseLParen() || parser.parseKeyword(&defval) ||
203           parser.parseRParen())
204         return failure();
205       llvm::SmallString<16> attrval;
206       // The def prefix is required for the attribute as "private" is a keyword
207       // in C++
208       attrval += "def";
209       attrval += defval;
210       auto attr = parser.getBuilder().getStringAttr(attrval);
211       result.addAttribute("default_val", attr);
212     } else if (keyword == "proc_bind") {
213       // fail if there was already another default clause
214       if (procBind)
215         return allowedOnce(parser, "proc_bind", opName);
216       procBind = true;
217       llvm::StringRef bind;
218       if (parser.parseLParen() || parser.parseKeyword(&bind) ||
219           parser.parseRParen())
220         return failure();
221       auto attr = parser.getBuilder().getStringAttr(bind);
222       result.addAttribute("proc_bind_val", attr);
223     } else {
224       return parser.emitError(parser.getNameLoc())
225              << keyword << " is not a valid clause for the " << opName
226              << " operation";
227     }
228   }
229 
230   // Add if parameter
231   if (segments[ifClausePos]) {
232     parser.resolveOperand(ifCond, parser.getBuilder().getI1Type(),
233                           result.operands);
234   }
235 
236   // Add num_threads parameter
237   if (segments[numThreadsClausePos]) {
238     parser.resolveOperand(numThreads.first, numThreads.second, result.operands);
239   }
240 
241   // Add private parameters
242   if (segments[privateClausePos]) {
243     parser.resolveOperands(privates, privateTypes, privates[0].location,
244                            result.operands);
245   }
246 
247   // Add firstprivate parameters
248   if (segments[firstprivateClausePos]) {
249     parser.resolveOperands(firstprivates, firstprivateTypes,
250                            firstprivates[0].location, result.operands);
251   }
252 
253   // Add shared parameters
254   if (segments[sharedClausePos]) {
255     parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
256                            result.operands);
257   }
258 
259   // Add copyin parameters
260   if (segments[copyinClausePos]) {
261     parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
262                            result.operands);
263   }
264 
265   result.addAttribute("operand_segment_sizes",
266                       parser.getBuilder().getI32VectorAttr(segments));
267 
268   Region *body = result.addRegion();
269   llvm::SmallVector<OpAsmParser::OperandType, 4> regionArgs;
270   llvm::SmallVector<Type, 4> regionArgTypes;
271   if (parser.parseRegion(*body, regionArgs, regionArgTypes))
272     return failure();
273   return success();
274 }
275 
// Include the generated operation class definitions. GET_OP_CLASSES selects
// that section of OpenMPOps.cpp.inc; GET_OP_LIST, in the dialect constructor,
// selects the op list instead. The include is presumably kept at the end of
// the file so that the generated code can refer to the static custom
// parse/print functions defined above (e.g. parseParallelOp/printParallelOp).
// Verify this against the ODS definition before moving it.
namespace mlir {
namespace omp {
#define GET_OP_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
} // namespace omp
} // namespace mlir
282