1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include <array>
#include <cstddef>
#include <tuple>
23 
24 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
25 
26 using namespace mlir;
27 using namespace mlir::omp;
28 
/// Registers the dialect's operations. The operation type list passed to
/// addOperations is expanded from OpenMPOps.cpp.inc with GET_OP_LIST defined,
/// so the registered set always matches that included list.
void OpenMPDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();
}
35 
36 //===----------------------------------------------------------------------===//
37 // ParallelOp
38 //===----------------------------------------------------------------------===//
39 
40 /// Parse a list of operands with types.
41 ///
42 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
43 /// ssa-id-and-type-list ::= ssa-id-and-type |
44 ///                          ssa-id-and-type `,` ssa-id-and-type-list
45 /// ssa-id-and-type ::= ssa-id `:` type
46 static ParseResult
47 parseOperandAndTypeList(OpAsmParser &parser,
48                         SmallVectorImpl<OpAsmParser::OperandType> &operands,
49                         SmallVectorImpl<Type> &types) {
50   if (parser.parseLParen())
51     return failure();
52 
53   do {
54     OpAsmParser::OperandType operand;
55     Type type;
56     if (parser.parseOperand(operand) || parser.parseColonType(type))
57       return failure();
58     operands.push_back(operand);
59     types.push_back(type);
60   } while (succeeded(parser.parseOptionalComma()));
61 
62   if (parser.parseRParen())
63     return failure();
64 
65   return success();
66 }
67 
68 /// Parse an allocate clause with allocators and a list of operands with types.
69 ///
70 /// operand-and-type-list ::= `(` allocate-operand-list `)`
71 /// allocate-operand-list :: = allocate-operand |
72 ///                            allocator-operand `,` allocate-operand-list
73 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
74 /// ssa-id-and-type ::= ssa-id `:` type
75 static ParseResult parseAllocateAndAllocator(
76     OpAsmParser &parser,
77     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
78     SmallVectorImpl<Type> &typesAllocate,
79     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
80     SmallVectorImpl<Type> &typesAllocator) {
81   if (parser.parseLParen())
82     return failure();
83 
84   do {
85     OpAsmParser::OperandType operand;
86     Type type;
87 
88     if (parser.parseOperand(operand) || parser.parseColonType(type))
89       return failure();
90     operandsAllocator.push_back(operand);
91     typesAllocator.push_back(type);
92     if (parser.parseArrow())
93       return failure();
94     if (parser.parseOperand(operand) || parser.parseColonType(type))
95       return failure();
96 
97     operandsAllocate.push_back(operand);
98     typesAllocate.push_back(type);
99   } while (succeeded(parser.parseOptionalComma()));
100 
101   if (parser.parseRParen())
102     return failure();
103 
104   return success();
105 }
106 
107 static LogicalResult verifyParallelOp(ParallelOp op) {
108   if (op.allocate_vars().size() != op.allocators_vars().size())
109     return op.emitError(
110         "expected equal sizes for allocate and allocator variables");
111   return success();
112 }
113 
114 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
115   p << "omp.parallel";
116 
117   if (auto ifCond = op.if_expr_var())
118     p << " if(" << ifCond << " : " << ifCond.getType() << ")";
119 
120   if (auto threads = op.num_threads_var())
121     p << " num_threads(" << threads << " : " << threads.getType() << ")";
122 
123   // Print private, firstprivate, shared and copyin parameters
124   auto printDataVars = [&p](StringRef name, OperandRange vars) {
125     if (vars.size()) {
126       p << " " << name << "(";
127       for (unsigned i = 0; i < vars.size(); ++i) {
128         std::string separator = i == vars.size() - 1 ? ")" : ", ";
129         p << vars[i] << " : " << vars[i].getType() << separator;
130       }
131     }
132   };
133 
134   // Print allocator and allocate parameters
135   auto printAllocateAndAllocator = [&p](OperandRange varsAllocate,
136                                         OperandRange varsAllocator) {
137     if (varsAllocate.empty())
138       return;
139 
140     p << " allocate(";
141     for (unsigned i = 0; i < varsAllocate.size(); ++i) {
142       std::string separator = i == varsAllocate.size() - 1 ? ")" : ", ";
143       p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
144       p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
145     }
146   };
147 
148   printDataVars("private", op.private_vars());
149   printDataVars("firstprivate", op.firstprivate_vars());
150   printDataVars("shared", op.shared_vars());
151   printDataVars("copyin", op.copyin_vars());
152   printAllocateAndAllocator(op.allocate_vars(), op.allocators_vars());
153 
154   if (auto def = op.default_val())
155     p << " default(" << def->drop_front(3) << ")";
156 
157   if (auto bind = op.proc_bind_val())
158     p << " proc_bind(" << bind << ")";
159 
160   p.printRegion(op.getRegion());
161 }
162 
163 /// Emit an error if the same clause is present more than once on an operation.
164 static ParseResult allowedOnce(OpAsmParser &parser, llvm::StringRef clause,
165                                llvm::StringRef operation) {
166   return parser.emitError(parser.getNameLoc())
167          << " at most one " << clause << " clause can appear on the "
168          << operation << " operation";
169 }
170 
171 /// Parses a parallel operation.
172 ///
173 /// operation ::= `omp.parallel` clause-list
174 /// clause-list ::= clause | clause clause-list
175 /// clause ::= if | numThreads | private | firstprivate | shared | copyin |
176 ///            default | procBind
177 /// if ::= `if` `(` ssa-id `)`
178 /// numThreads ::= `num_threads` `(` ssa-id-and-type `)`
179 /// private ::= `private` operand-and-type-list
180 /// firstprivate ::= `firstprivate` operand-and-type-list
181 /// shared ::= `shared` operand-and-type-list
182 /// copyin ::= `copyin` operand-and-type-list
183 /// allocate ::= `allocate` operand-and-type `->` operand-and-type-list
184 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`)
185 /// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
186 ///
187 /// Note that each clause can only appear once in the clase-list.
188 static ParseResult parseParallelOp(OpAsmParser &parser,
189                                    OperationState &result) {
190   std::pair<OpAsmParser::OperandType, Type> ifCond;
191   std::pair<OpAsmParser::OperandType, Type> numThreads;
192   SmallVector<OpAsmParser::OperandType, 4> privates;
193   SmallVector<Type, 4> privateTypes;
194   SmallVector<OpAsmParser::OperandType, 4> firstprivates;
195   SmallVector<Type, 4> firstprivateTypes;
196   SmallVector<OpAsmParser::OperandType, 4> shareds;
197   SmallVector<Type, 4> sharedTypes;
198   SmallVector<OpAsmParser::OperandType, 4> copyins;
199   SmallVector<Type, 4> copyinTypes;
200   SmallVector<OpAsmParser::OperandType, 4> allocates;
201   SmallVector<Type, 4> allocateTypes;
202   SmallVector<OpAsmParser::OperandType, 4> allocators;
203   SmallVector<Type, 4> allocatorTypes;
204   std::array<int, 8> segments{0, 0, 0, 0, 0, 0, 0, 0};
205   llvm::StringRef keyword;
206   bool defaultVal = false;
207   bool procBind = false;
208 
209   const int ifClausePos = 0;
210   const int numThreadsClausePos = 1;
211   const int privateClausePos = 2;
212   const int firstprivateClausePos = 3;
213   const int sharedClausePos = 4;
214   const int copyinClausePos = 5;
215   const int allocateClausePos = 6;
216   const int allocatorPos = 7;
217   const llvm::StringRef opName = result.name.getStringRef();
218 
219   while (succeeded(parser.parseOptionalKeyword(&keyword))) {
220     if (keyword == "if") {
221       // Fail if there was already another if condition
222       if (segments[ifClausePos])
223         return allowedOnce(parser, "if", opName);
224       if (parser.parseLParen() || parser.parseOperand(ifCond.first) ||
225           parser.parseColonType(ifCond.second) || parser.parseRParen())
226         return failure();
227       segments[ifClausePos] = 1;
228     } else if (keyword == "num_threads") {
229       // fail if there was already another num_threads clause
230       if (segments[numThreadsClausePos])
231         return allowedOnce(parser, "num_threads", opName);
232       if (parser.parseLParen() || parser.parseOperand(numThreads.first) ||
233           parser.parseColonType(numThreads.second) || parser.parseRParen())
234         return failure();
235       segments[numThreadsClausePos] = 1;
236     } else if (keyword == "private") {
237       // fail if there was already another private clause
238       if (segments[privateClausePos])
239         return allowedOnce(parser, "private", opName);
240       if (parseOperandAndTypeList(parser, privates, privateTypes))
241         return failure();
242       segments[privateClausePos] = privates.size();
243     } else if (keyword == "firstprivate") {
244       // fail if there was already another firstprivate clause
245       if (segments[firstprivateClausePos])
246         return allowedOnce(parser, "firstprivate", opName);
247       if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
248         return failure();
249       segments[firstprivateClausePos] = firstprivates.size();
250     } else if (keyword == "shared") {
251       // fail if there was already another shared clause
252       if (segments[sharedClausePos])
253         return allowedOnce(parser, "shared", opName);
254       if (parseOperandAndTypeList(parser, shareds, sharedTypes))
255         return failure();
256       segments[sharedClausePos] = shareds.size();
257     } else if (keyword == "copyin") {
258       // fail if there was already another copyin clause
259       if (segments[copyinClausePos])
260         return allowedOnce(parser, "copyin", opName);
261       if (parseOperandAndTypeList(parser, copyins, copyinTypes))
262         return failure();
263       segments[copyinClausePos] = copyins.size();
264     } else if (keyword == "allocate") {
265       // fail if there was already another allocate clause
266       if (segments[allocateClausePos])
267         return allowedOnce(parser, "allocate", opName);
268       if (parseAllocateAndAllocator(parser, allocates, allocateTypes,
269                                     allocators, allocatorTypes))
270         return failure();
271       segments[allocateClausePos] = allocates.size();
272       segments[allocatorPos] = allocators.size();
273     } else if (keyword == "default") {
274       // fail if there was already another default clause
275       if (defaultVal)
276         return allowedOnce(parser, "default", opName);
277       defaultVal = true;
278       llvm::StringRef defval;
279       if (parser.parseLParen() || parser.parseKeyword(&defval) ||
280           parser.parseRParen())
281         return failure();
282       llvm::SmallString<16> attrval;
283       // The def prefix is required for the attribute as "private" is a keyword
284       // in C++
285       attrval += "def";
286       attrval += defval;
287       auto attr = parser.getBuilder().getStringAttr(attrval);
288       result.addAttribute("default_val", attr);
289     } else if (keyword == "proc_bind") {
290       // fail if there was already another proc_bind clause
291       if (procBind)
292         return allowedOnce(parser, "proc_bind", opName);
293       procBind = true;
294       llvm::StringRef bind;
295       if (parser.parseLParen() || parser.parseKeyword(&bind) ||
296           parser.parseRParen())
297         return failure();
298       auto attr = parser.getBuilder().getStringAttr(bind);
299       result.addAttribute("proc_bind_val", attr);
300     } else {
301       return parser.emitError(parser.getNameLoc())
302              << keyword << " is not a valid clause for the " << opName
303              << " operation";
304     }
305   }
306 
307   // Add if parameter
308   if (segments[ifClausePos] &&
309       parser.resolveOperand(ifCond.first, ifCond.second, result.operands))
310     return failure();
311 
312   // Add num_threads parameter
313   if (segments[numThreadsClausePos] &&
314       parser.resolveOperand(numThreads.first, numThreads.second,
315                             result.operands))
316     return failure();
317 
318   // Add private parameters
319   if (segments[privateClausePos] &&
320       parser.resolveOperands(privates, privateTypes, privates[0].location,
321                              result.operands))
322     return failure();
323 
324   // Add firstprivate parameters
325   if (segments[firstprivateClausePos] &&
326       parser.resolveOperands(firstprivates, firstprivateTypes,
327                              firstprivates[0].location, result.operands))
328     return failure();
329 
330   // Add shared parameters
331   if (segments[sharedClausePos] &&
332       parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
333                              result.operands))
334     return failure();
335 
336   // Add copyin parameters
337   if (segments[copyinClausePos] &&
338       parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
339                              result.operands))
340     return failure();
341 
342   // Add allocate parameters
343   if (segments[allocateClausePos] &&
344       parser.resolveOperands(allocates, allocateTypes, allocates[0].location,
345                              result.operands))
346     return failure();
347 
348   // Add allocator parameters
349   if (segments[allocatorPos] &&
350       parser.resolveOperands(allocators, allocatorTypes, allocators[0].location,
351                              result.operands))
352     return failure();
353 
354   result.addAttribute("operand_segment_sizes",
355                       parser.getBuilder().getI32VectorAttr(segments));
356 
357   Region *body = result.addRegion();
358   SmallVector<OpAsmParser::OperandType, 4> regionArgs;
359   SmallVector<Type, 4> regionArgTypes;
360   if (parser.parseRegion(*body, regionArgs, regionArgTypes))
361     return failure();
362   return success();
363 }
364 
365 #define GET_OP_CLASSES
366 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
367