1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/OpImplementation.h"
17 #include "mlir/IR/OperationSupport.h"
18 
19 #include "llvm/ADT/SmallString.h"
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/ADT/StringSwitch.h"
22 #include <cstddef>
23 
24 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
25 
26 using namespace mlir;
27 using namespace mlir::omp;
28 
/// Registers the dialect's operations.
///
/// The template argument list of `addOperations` is supplied by including
/// OpenMPOps.cpp.inc with GET_OP_LIST defined.
// NOTE(review): presumably the .inc file is generated from the dialect's
// TableGen definitions and expands to a comma-separated list of op classes
// here; confirm against the build.
void OpenMPDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();
}
35 
36 //===----------------------------------------------------------------------===//
37 // ParallelOp
38 //===----------------------------------------------------------------------===//
39 
40 /// Parse a list of operands with types.
41 ///
42 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
43 /// ssa-id-and-type-list ::= ssa-id-and-type |
44 ///                          ssa-id-and-type ',' ssa-id-and-type-list
45 /// ssa-id-and-type ::= ssa-id `:` type
46 static ParseResult
47 parseOperandAndTypeList(OpAsmParser &parser,
48                         SmallVectorImpl<OpAsmParser::OperandType> &operands,
49                         SmallVectorImpl<Type> &types) {
50   if (parser.parseLParen())
51     return failure();
52 
53   do {
54     OpAsmParser::OperandType operand;
55     Type type;
56     if (parser.parseOperand(operand) || parser.parseColonType(type))
57       return failure();
58     operands.push_back(operand);
59     types.push_back(type);
60   } while (succeeded(parser.parseOptionalComma()));
61 
62   if (parser.parseRParen())
63     return failure();
64 
65   return success();
66 }
67 
68 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
69   p << "omp.parallel";
70 
71   if (auto ifCond = op.if_expr_var())
72     p << " if(" << ifCond << " : " << ifCond.getType() << ")";
73 
74   if (auto threads = op.num_threads_var())
75     p << " num_threads(" << threads << " : " << threads.getType() << ")";
76 
77   // Print private, firstprivate, shared and copyin parameters
78   auto printDataVars = [&p](StringRef name, OperandRange vars) {
79     if (vars.size()) {
80       p << " " << name << "(";
81       for (unsigned i = 0; i < vars.size(); ++i) {
82         std::string separator = i == vars.size() - 1 ? ")" : ", ";
83         p << vars[i] << " : " << vars[i].getType() << separator;
84       }
85     }
86   };
87   printDataVars("private", op.private_vars());
88   printDataVars("firstprivate", op.firstprivate_vars());
89   printDataVars("shared", op.shared_vars());
90   printDataVars("copyin", op.copyin_vars());
91 
92   if (auto def = op.default_val())
93     p << " default(" << def->drop_front(3) << ")";
94 
95   if (auto bind = op.proc_bind_val())
96     p << " proc_bind(" << bind << ")";
97 
98   p.printRegion(op.getRegion());
99 }
100 
101 /// Emit an error if the same clause is present more than once on an operation.
102 static ParseResult allowedOnce(OpAsmParser &parser, llvm::StringRef clause,
103                                llvm::StringRef operation) {
104   return parser.emitError(parser.getNameLoc())
105          << " at most one " << clause << " clause can appear on the "
106          << operation << " operation";
107 }
108 
109 /// Parses a parallel operation.
110 ///
111 /// operation ::= `omp.parallel` clause-list
112 /// clause-list ::= clause | clause clause-list
113 /// clause ::= if | numThreads | private | firstprivate | shared | copyin |
114 ///            default | procBind
115 /// if ::= `if` `(` ssa-id `)`
116 /// numThreads ::= `num_threads` `(` ssa-id-and-type `)`
117 /// private ::= `private` operand-and-type-list
118 /// firstprivate ::= `firstprivate` operand-and-type-list
119 /// shared ::= `shared` operand-and-type-list
120 /// copyin ::= `copyin` operand-and-type-list
121 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`)
122 /// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
123 ///
124 /// Note that each clause can only appear once in the clase-list.
125 static ParseResult parseParallelOp(OpAsmParser &parser,
126                                    OperationState &result) {
127   std::pair<OpAsmParser::OperandType, Type> ifCond;
128   std::pair<OpAsmParser::OperandType, Type> numThreads;
129   llvm::SmallVector<OpAsmParser::OperandType, 4> privates;
130   llvm::SmallVector<Type, 4> privateTypes;
131   llvm::SmallVector<OpAsmParser::OperandType, 4> firstprivates;
132   llvm::SmallVector<Type, 4> firstprivateTypes;
133   llvm::SmallVector<OpAsmParser::OperandType, 4> shareds;
134   llvm::SmallVector<Type, 4> sharedTypes;
135   llvm::SmallVector<OpAsmParser::OperandType, 4> copyins;
136   llvm::SmallVector<Type, 4> copyinTypes;
137   std::array<int, 6> segments{0, 0, 0, 0, 0, 0};
138   llvm::StringRef keyword;
139   bool defaultVal = false;
140   bool procBind = false;
141 
142   const int ifClausePos = 0;
143   const int numThreadsClausePos = 1;
144   const int privateClausePos = 2;
145   const int firstprivateClausePos = 3;
146   const int sharedClausePos = 4;
147   const int copyinClausePos = 5;
148   const llvm::StringRef opName = result.name.getStringRef();
149 
150   while (succeeded(parser.parseOptionalKeyword(&keyword))) {
151     if (keyword == "if") {
152       // Fail if there was already another if condition
153       if (segments[ifClausePos])
154         return allowedOnce(parser, "if", opName);
155       if (parser.parseLParen() || parser.parseOperand(ifCond.first) ||
156           parser.parseColonType(ifCond.second) || parser.parseRParen())
157         return failure();
158       segments[ifClausePos] = 1;
159     } else if (keyword == "num_threads") {
160       // fail if there was already another num_threads clause
161       if (segments[numThreadsClausePos])
162         return allowedOnce(parser, "num_threads", opName);
163       if (parser.parseLParen() || parser.parseOperand(numThreads.first) ||
164           parser.parseColonType(numThreads.second) || parser.parseRParen())
165         return failure();
166       segments[numThreadsClausePos] = 1;
167     } else if (keyword == "private") {
168       // fail if there was already another private clause
169       if (segments[privateClausePos])
170         return allowedOnce(parser, "private", opName);
171       if (parseOperandAndTypeList(parser, privates, privateTypes))
172         return failure();
173       segments[privateClausePos] = privates.size();
174     } else if (keyword == "firstprivate") {
175       // fail if there was already another firstprivate clause
176       if (segments[firstprivateClausePos])
177         return allowedOnce(parser, "firstprivate", opName);
178       if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
179         return failure();
180       segments[firstprivateClausePos] = firstprivates.size();
181     } else if (keyword == "shared") {
182       // fail if there was already another shared clause
183       if (segments[sharedClausePos])
184         return allowedOnce(parser, "shared", opName);
185       if (parseOperandAndTypeList(parser, shareds, sharedTypes))
186         return failure();
187       segments[sharedClausePos] = shareds.size();
188     } else if (keyword == "copyin") {
189       // fail if there was already another copyin clause
190       if (segments[copyinClausePos])
191         return allowedOnce(parser, "copyin", opName);
192       if (parseOperandAndTypeList(parser, copyins, copyinTypes))
193         return failure();
194       segments[copyinClausePos] = copyins.size();
195     } else if (keyword == "default") {
196       // fail if there was already another default clause
197       if (defaultVal)
198         return allowedOnce(parser, "default", opName);
199       defaultVal = true;
200       llvm::StringRef defval;
201       if (parser.parseLParen() || parser.parseKeyword(&defval) ||
202           parser.parseRParen())
203         return failure();
204       llvm::SmallString<16> attrval;
205       // The def prefix is required for the attribute as "private" is a keyword
206       // in C++
207       attrval += "def";
208       attrval += defval;
209       auto attr = parser.getBuilder().getStringAttr(attrval);
210       result.addAttribute("default_val", attr);
211     } else if (keyword == "proc_bind") {
212       // fail if there was already another proc_bind clause
213       if (procBind)
214         return allowedOnce(parser, "proc_bind", opName);
215       procBind = true;
216       llvm::StringRef bind;
217       if (parser.parseLParen() || parser.parseKeyword(&bind) ||
218           parser.parseRParen())
219         return failure();
220       auto attr = parser.getBuilder().getStringAttr(bind);
221       result.addAttribute("proc_bind_val", attr);
222     } else {
223       return parser.emitError(parser.getNameLoc())
224              << keyword << " is not a valid clause for the " << opName
225              << " operation";
226     }
227   }
228 
229   // Add if parameter
230   if (segments[ifClausePos]) {
231     parser.resolveOperand(ifCond.first, ifCond.second, result.operands);
232   }
233 
234   // Add num_threads parameter
235   if (segments[numThreadsClausePos]) {
236     parser.resolveOperand(numThreads.first, numThreads.second, result.operands);
237   }
238 
239   // Add private parameters
240   if (segments[privateClausePos]) {
241     parser.resolveOperands(privates, privateTypes, privates[0].location,
242                            result.operands);
243   }
244 
245   // Add firstprivate parameters
246   if (segments[firstprivateClausePos]) {
247     parser.resolveOperands(firstprivates, firstprivateTypes,
248                            firstprivates[0].location, result.operands);
249   }
250 
251   // Add shared parameters
252   if (segments[sharedClausePos]) {
253     parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
254                            result.operands);
255   }
256 
257   // Add copyin parameters
258   if (segments[copyinClausePos]) {
259     parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
260                            result.operands);
261   }
262 
263   result.addAttribute("operand_segment_sizes",
264                       parser.getBuilder().getI32VectorAttr(segments));
265 
266   Region *body = result.addRegion();
267   llvm::SmallVector<OpAsmParser::OperandType, 4> regionArgs;
268   llvm::SmallVector<Type, 4> regionArgTypes;
269   if (parser.parseRegion(*body, regionArgs, regionArgTypes))
270     return failure();
271   return success();
272 }
273 
// Includes OpenMPOps.cpp.inc with GET_OP_CLASSES defined, inside the
// mlir::omp namespace.
// NOTE(review): presumably this expands to the generated op class
// definitions, which would reference the static parse/print helpers defined
// earlier in this file; confirm against the generated output.
namespace mlir {
namespace omp {
#define GET_OP_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
} // namespace omp
} // namespace mlir
280