1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/OperationSupport.h"
19 
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/SmallString.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/ADT/StringSwitch.h"
25 #include <cstddef>
26 
27 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
28 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
29 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
30 
31 using namespace mlir;
32 using namespace mlir::omp;
33 
34 namespace {
35 /// Model for pointer-like types that already provide a `getElementType` method.
36 template <typename T>
37 struct PointerLikeModel
38     : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
39   Type getElementType(Type pointer) const {
40     return pointer.cast<T>().getElementType();
41   }
42 };
43 } // end namespace
44 
45 void OpenMPDialect::initialize() {
46   addOperations<
47 #define GET_OP_LIST
48 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
49       >();
50 
51   LLVM::LLVMPointerType::attachInterface<
52       PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
53   MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // ParallelOp
58 //===----------------------------------------------------------------------===//
59 
60 void ParallelOp::build(OpBuilder &builder, OperationState &state,
61                        ArrayRef<NamedAttribute> attributes) {
62   ParallelOp::build(
63       builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
64       /*default_val=*/nullptr, /*private_vars=*/ValueRange(),
65       /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(),
66       /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(),
67       /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr);
68   state.addAttributes(attributes);
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // Parser and printer for Operand and type list
73 //===----------------------------------------------------------------------===//
74 
75 /// Parse a list of operands with types.
76 ///
77 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
78 /// ssa-id-and-type-list ::= ssa-id-and-type |
79 ///                          ssa-id-and-type `,` ssa-id-and-type-list
80 /// ssa-id-and-type ::= ssa-id `:` type
81 static ParseResult
82 parseOperandAndTypeList(OpAsmParser &parser,
83                         SmallVectorImpl<OpAsmParser::OperandType> &operands,
84                         SmallVectorImpl<Type> &types) {
85   return parser.parseCommaSeparatedList(
86       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
87         OpAsmParser::OperandType operand;
88         Type type;
89         if (parser.parseOperand(operand) || parser.parseColonType(type))
90           return failure();
91         operands.push_back(operand);
92         types.push_back(type);
93         return success();
94       });
95 }
96 
97 /// Print an operand and type list with parentheses
98 static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) {
99   p << "(";
100   llvm::interleaveComma(
101       operands, p, [&](const Value &v) { p << v << " : " << v.getType(); });
102   p << ") ";
103 }
104 
105 /// Print data variables corresponding to a data-sharing clause `name`
106 static void printDataVars(OpAsmPrinter &p, OperandRange operands,
107                           StringRef name) {
108   if (operands.size()) {
109     p << name;
110     printOperandAndTypeList(p, operands);
111   }
112 }
113 
114 //===----------------------------------------------------------------------===//
115 // Parser and printer for Allocate Clause
116 //===----------------------------------------------------------------------===//
117 
118 /// Parse an allocate clause with allocators and a list of operands with types.
119 ///
120 /// allocate ::= `allocate` `(` allocate-operand-list `)`
121 /// allocate-operand-list :: = allocate-operand |
122 ///                            allocator-operand `,` allocate-operand-list
123 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
124 /// ssa-id-and-type ::= ssa-id `:` type
125 static ParseResult parseAllocateAndAllocator(
126     OpAsmParser &parser,
127     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
128     SmallVectorImpl<Type> &typesAllocate,
129     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
130     SmallVectorImpl<Type> &typesAllocator) {
131 
132   return parser.parseCommaSeparatedList(
133       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
134         OpAsmParser::OperandType operand;
135         Type type;
136         if (parser.parseOperand(operand) || parser.parseColonType(type))
137           return failure();
138         operandsAllocator.push_back(operand);
139         typesAllocator.push_back(type);
140         if (parser.parseArrow())
141           return failure();
142         if (parser.parseOperand(operand) || parser.parseColonType(type))
143           return failure();
144 
145         operandsAllocate.push_back(operand);
146         typesAllocate.push_back(type);
147         return success();
148       });
149 }
150 
151 /// Print allocate clause
152 static void printAllocateAndAllocator(OpAsmPrinter &p,
153                                       OperandRange varsAllocate,
154                                       OperandRange varsAllocator) {
155   if (varsAllocate.empty())
156     return;
157 
158   p << "allocate(";
159   for (unsigned i = 0; i < varsAllocate.size(); ++i) {
160     std::string separator = i == varsAllocate.size() - 1 ? ") " : ", ";
161     p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
162     p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
163   }
164 }
165 
166 static LogicalResult verifyParallelOp(ParallelOp op) {
167   if (op.allocate_vars().size() != op.allocators_vars().size())
168     return op.emitError(
169         "expected equal sizes for allocate and allocator variables");
170   return success();
171 }
172 
173 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
174   p << " ";
175   if (auto ifCond = op.if_expr_var())
176     p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
177 
178   if (auto threads = op.num_threads_var())
179     p << "num_threads(" << threads << " : " << threads.getType() << ") ";
180 
181   printDataVars(p, op.private_vars(), "private");
182   printDataVars(p, op.firstprivate_vars(), "firstprivate");
183   printDataVars(p, op.shared_vars(), "shared");
184   printDataVars(p, op.copyin_vars(), "copyin");
185   printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars());
186 
187   if (auto def = op.default_val())
188     p << "default(" << def->drop_front(3) << ") ";
189 
190   if (auto bind = op.proc_bind_val())
191     p << "proc_bind(" << bind << ") ";
192 
193   p.printRegion(op.getRegion());
194 }
195 
196 //===----------------------------------------------------------------------===//
197 // Parser and printer for Linear Clause
198 //===----------------------------------------------------------------------===//
199 
200 /// linear ::= `linear` `(` linear-list `)`
201 /// linear-list := linear-val | linear-val linear-list
202 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
203 static ParseResult
204 parseLinearClause(OpAsmParser &parser,
205                   SmallVectorImpl<OpAsmParser::OperandType> &vars,
206                   SmallVectorImpl<Type> &types,
207                   SmallVectorImpl<OpAsmParser::OperandType> &stepVars) {
208   if (parser.parseLParen())
209     return failure();
210 
211   do {
212     OpAsmParser::OperandType var;
213     Type type;
214     OpAsmParser::OperandType stepVar;
215     if (parser.parseOperand(var) || parser.parseEqual() ||
216         parser.parseOperand(stepVar) || parser.parseColonType(type))
217       return failure();
218 
219     vars.push_back(var);
220     types.push_back(type);
221     stepVars.push_back(stepVar);
222   } while (succeeded(parser.parseOptionalComma()));
223 
224   if (parser.parseRParen())
225     return failure();
226 
227   return success();
228 }
229 
230 /// Print Linear Clause
231 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars,
232                               OperandRange linearStepVars) {
233   size_t linearVarsSize = linearVars.size();
234   p << "(";
235   for (unsigned i = 0; i < linearVarsSize; ++i) {
236     std::string separator = i == linearVarsSize - 1 ? ") " : ", ";
237     p << linearVars[i];
238     if (linearStepVars.size() > i)
239       p << " = " << linearStepVars[i];
240     p << " : " << linearVars[i].getType() << separator;
241   }
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // Parser and printer for Schedule Clause
246 //===----------------------------------------------------------------------===//
247 
248 /// schedule ::= `schedule` `(` sched-list `)`
249 /// sched-list ::= sched-val | sched-val sched-list
250 /// sched-val ::= sched-with-chunk | sched-wo-chunk
251 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
252 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
253 /// sched-wo-chunk ::=  `auto` | `runtime`
254 static ParseResult
255 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
256                     Optional<OpAsmParser::OperandType> &chunkSize) {
257   if (parser.parseLParen())
258     return failure();
259 
260   StringRef keyword;
261   if (parser.parseKeyword(&keyword))
262     return failure();
263 
264   schedule = keyword;
265   if (keyword == "static" || keyword == "dynamic" || keyword == "guided") {
266     if (succeeded(parser.parseOptionalEqual())) {
267       chunkSize = OpAsmParser::OperandType{};
268       if (parser.parseOperand(*chunkSize))
269         return failure();
270     } else {
271       chunkSize = llvm::NoneType::None;
272     }
273   } else if (keyword == "auto" || keyword == "runtime") {
274     chunkSize = llvm::NoneType::None;
275   } else {
276     return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
277   }
278 
279   if (parser.parseRParen())
280     return failure();
281 
282   return success();
283 }
284 
285 /// Print schedule clause
286 static void printScheduleClause(OpAsmPrinter &p, StringRef &sched,
287                                 Value scheduleChunkVar) {
288   std::string schedLower = sched.lower();
289   p << "(" << schedLower;
290   if (scheduleChunkVar)
291     p << " = " << scheduleChunkVar;
292   p << ") ";
293 }
294 
295 //===----------------------------------------------------------------------===//
296 // Parser, printer and verifier for ReductionVarList
297 //===----------------------------------------------------------------------===//
298 
299 /// reduction ::= `reduction` `(` reduction-entry-list `)`
300 /// reduction-entry-list ::= reduction-entry
301 ///                        | reduction-entry-list `,` reduction-entry
302 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
303 static ParseResult
304 parseReductionVarList(OpAsmParser &parser,
305                       SmallVectorImpl<SymbolRefAttr> &symbols,
306                       SmallVectorImpl<OpAsmParser::OperandType> &operands,
307                       SmallVectorImpl<Type> &types) {
308   if (failed(parser.parseLParen()))
309     return failure();
310 
311   do {
312     if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() ||
313         parser.parseOperand(operands.emplace_back()) ||
314         parser.parseColonType(types.emplace_back()))
315       return failure();
316   } while (succeeded(parser.parseOptionalComma()));
317   return parser.parseRParen();
318 }
319 
320 /// Print Reduction clause
321 static void printReductionVarList(OpAsmPrinter &p,
322                                   Optional<ArrayAttr> reductions,
323                                   OperandRange reduction_vars) {
324   for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
325     if (i != 0)
326       p << ", ";
327     p << (*reductions)[i] << " -> " << reduction_vars[i] << " : "
328       << reduction_vars[i].getType();
329   }
330   p << ") ";
331 }
332 
333 /// Verifies Reduction Clause
334 static LogicalResult verifyReductionVarList(Operation *op,
335                                             Optional<ArrayAttr> reductions,
336                                             OperandRange reduction_vars) {
337   if (reduction_vars.size() != 0) {
338     if (!reductions || reductions->size() != reduction_vars.size())
339       return op->emitOpError()
340              << "expected as many reduction symbol references "
341                 "as reduction variables";
342   } else {
343     if (reductions)
344       return op->emitOpError() << "unexpected reduction symbol references";
345     return success();
346   }
347 
348   DenseSet<Value> accumulators;
349   for (auto args : llvm::zip(reduction_vars, *reductions)) {
350     Value accum = std::get<0>(args);
351 
352     if (!accumulators.insert(accum).second)
353       return op->emitOpError() << "accumulator variable used more than once";
354 
355     Type varType = accum.getType().cast<PointerLikeType>();
356     auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
357     auto decl =
358         SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
359     if (!decl)
360       return op->emitOpError() << "expected symbol reference " << symbolRef
361                                << " to point to a reduction declaration";
362 
363     if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
364       return op->emitOpError()
365              << "expected accumulator (" << varType
366              << ") to be the same type as reduction declaration ("
367              << decl.getAccumulatorType() << ")";
368   }
369 
370   return success();
371 }
372 
373 //===----------------------------------------------------------------------===//
374 // Parser, printer and verifier for Synchronization Hint (2.17.12)
375 //===----------------------------------------------------------------------===//
376 
377 /// Parses a Synchronization Hint clause. The value of hint is an integer
378 /// which is a combination of different hints from `omp_sync_hint_t`.
379 ///
380 /// hint-clause = `hint` `(` hint-value `)`
381 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
382                                             IntegerAttr &hintAttr) {
383   if (failed(parser.parseOptionalKeyword("hint"))) {
384     hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
385     return success();
386   }
387 
388   if (failed(parser.parseLParen()))
389     return failure();
390   StringRef hintKeyword;
391   int64_t hint = 0;
392   do {
393     if (failed(parser.parseKeyword(&hintKeyword)))
394       return failure();
395     if (hintKeyword == "uncontended")
396       hint |= 1;
397     else if (hintKeyword == "contended")
398       hint |= 2;
399     else if (hintKeyword == "nonspeculative")
400       hint |= 4;
401     else if (hintKeyword == "speculative")
402       hint |= 8;
403     else
404       return parser.emitError(parser.getCurrentLocation())
405              << hintKeyword << " is not a valid hint";
406   } while (succeeded(parser.parseOptionalComma()));
407   if (failed(parser.parseRParen()))
408     return failure();
409   hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
410   return success();
411 }
412 
413 /// Prints a Synchronization Hint clause
414 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
415                                      IntegerAttr hintAttr) {
416   int64_t hint = hintAttr.getInt();
417 
418   if (hint == 0)
419     return;
420 
421   // Helper function to get n-th bit from the right end of `value`
422   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
423 
424   bool uncontended = bitn(hint, 0);
425   bool contended = bitn(hint, 1);
426   bool nonspeculative = bitn(hint, 2);
427   bool speculative = bitn(hint, 3);
428 
429   SmallVector<StringRef> hints;
430   if (uncontended)
431     hints.push_back("uncontended");
432   if (contended)
433     hints.push_back("contended");
434   if (nonspeculative)
435     hints.push_back("nonspeculative");
436   if (speculative)
437     hints.push_back("speculative");
438 
439   p << "hint(";
440   llvm::interleaveComma(hints, p);
441   p << ")";
442 }
443 
444 /// Verifies a synchronization hint clause
445 static LogicalResult verifySynchronizationHint(Operation *op, int32_t hint) {
446 
447   // Helper function to get n-th bit from the right end of `value`
448   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
449 
450   bool uncontended = bitn(hint, 0);
451   bool contended = bitn(hint, 1);
452   bool nonspeculative = bitn(hint, 2);
453   bool speculative = bitn(hint, 3);
454 
455   if (uncontended && contended)
456     return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
457                                 "omp_sync_hint_contended cannot be combined";
458   if (nonspeculative && speculative)
459     return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
460                                 "omp_sync_hint_speculative cannot be combined.";
461   return success();
462 }
463 
/// Identifiers for the clauses understood by the clause-list parser. The
/// enumerators are consecutive and start at zero so that they can be used
/// directly as indices; `COUNT` is the number of clause kinds.
enum ClauseType {
  // Single-operand clauses.
  ifClause = 0,
  numThreadsClause,

  // Data-sharing clauses.
  privateClause,
  firstprivateClause,
  lastprivateClause,
  sharedClause,
  copyinClause,
  allocateClause,

  // Remaining clauses.
  defaultClause,
  procBindClause,
  reductionClause,
  nowaitClause,
  linearClause,
  scheduleClause,
  collapseClause,
  orderClause,
  orderedClause,
  inclusiveClause,

  // Number of clause kinds; must stay last.
  COUNT
};
485 
486 //===----------------------------------------------------------------------===//
487 // Parser for Clause List
488 //===----------------------------------------------------------------------===//
489 
490 /// Parse a list of clauses. The clauses can appear in any order, but their
491 /// operand segment indices are in the same order that they are passed in the
492 /// `clauses` list. The operand segments are added over the prevSegments
493 
494 /// clause-list ::= clause clause-list | empty
495 /// clause ::= if | num-threads | private | firstprivate | lastprivate |
496 ///            shared | copyin | allocate | default | proc-bind | reduction |
497 ///            nowait | linear | schedule | collapse | order | ordered |
498 ///            inclusive
499 /// if ::= `if` `(` ssa-id-and-type `)`
500 /// num-threads ::= `num_threads` `(` ssa-id-and-type `)`
501 /// private ::= `private` operand-and-type-list
502 /// firstprivate ::= `firstprivate` operand-and-type-list
503 /// lastprivate ::= `lastprivate` operand-and-type-list
504 /// shared ::= `shared` operand-and-type-list
505 /// copyin ::= `copyin` operand-and-type-list
506 /// allocate ::= `allocate` `(` allocate-operand-list `)`
/// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`)
///             `)`
508 /// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
509 /// reduction ::= `reduction` `(` reduction-entry-list `)`
510 /// nowait ::= `nowait`
511 /// linear ::= `linear` `(` linear-list `)`
512 /// schedule ::= `schedule` `(` sched-list `)`
513 /// collapse ::= `collapse` `(` ssa-id-and-type `)`
514 /// order ::= `order` `(` `concurrent` `)`
515 /// ordered ::= `ordered` `(` ssa-id-and-type `)`
516 /// inclusive ::= `inclusive`
517 ///
/// Note that each clause can only appear once in the clause-list.
519 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
520                                 SmallVectorImpl<ClauseType> &clauses,
521                                 SmallVectorImpl<int> &segments) {
522 
523   // Check done[clause] to see if it has been parsed already
524   llvm::BitVector done(ClauseType::COUNT, false);
525 
526   // See pos[clause] to get position of clause in operand segments
527   SmallVector<int> pos(ClauseType::COUNT, -1);
528 
529   // Stores the last parsed clause keyword
530   StringRef clauseKeyword;
531   StringRef opName = result.name.getStringRef();
532 
533   // Containers for storing operands, types and attributes for various clauses
534   std::pair<OpAsmParser::OperandType, Type> ifCond;
535   std::pair<OpAsmParser::OperandType, Type> numThreads;
536 
537   SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates,
538       shareds, copyins;
539   SmallVector<Type> privateTypes, firstprivateTypes, lastprivateTypes,
540       sharedTypes, copyinTypes;
541 
542   SmallVector<OpAsmParser::OperandType> allocates, allocators;
543   SmallVector<Type> allocateTypes, allocatorTypes;
544 
545   SmallVector<SymbolRefAttr> reductionSymbols;
546   SmallVector<OpAsmParser::OperandType> reductionVars;
547   SmallVector<Type> reductionVarTypes;
548 
549   SmallVector<OpAsmParser::OperandType> linears;
550   SmallVector<Type> linearTypes;
551   SmallVector<OpAsmParser::OperandType> linearSteps;
552 
553   SmallString<8> schedule;
554   Optional<OpAsmParser::OperandType> scheduleChunkSize;
555 
556   // Compute the position of clauses in operand segments
557   int currPos = 0;
558   for (ClauseType clause : clauses) {
559 
560     // Skip the following clauses - they do not take any position in operand
561     // segments
562     if (clause == defaultClause || clause == procBindClause ||
563         clause == nowaitClause || clause == collapseClause ||
564         clause == orderClause || clause == orderedClause ||
565         clause == inclusiveClause)
566       continue;
567 
568     pos[clause] = currPos++;
569 
570     // For the following clauses, two positions are reserved in the operand
571     // segments
572     if (clause == allocateClause || clause == linearClause)
573       currPos++;
574   }
575 
576   SmallVector<int> clauseSegments(currPos);
577 
578   // Helper function to check if a clause is allowed/repeated or not
579   auto checkAllowed = [&](ClauseType clause,
580                           bool allowRepeat = false) -> ParseResult {
581     if (!llvm::is_contained(clauses, clause))
582       return parser.emitError(parser.getCurrentLocation())
583              << clauseKeyword << "is not a valid clause for the " << opName
584              << " operation";
585     if (done[clause] && !allowRepeat)
586       return parser.emitError(parser.getCurrentLocation())
587              << "at most one " << clauseKeyword << " clause can appear on the "
588              << opName << " operation";
589     done[clause] = true;
590     return success();
591   };
592 
593   while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) {
594     if (clauseKeyword == "if") {
595       if (checkAllowed(ifClause) || parser.parseLParen() ||
596           parser.parseOperand(ifCond.first) ||
597           parser.parseColonType(ifCond.second) || parser.parseRParen())
598         return failure();
599       clauseSegments[pos[ifClause]] = 1;
600     } else if (clauseKeyword == "num_threads") {
601       if (checkAllowed(numThreadsClause) || parser.parseLParen() ||
602           parser.parseOperand(numThreads.first) ||
603           parser.parseColonType(numThreads.second) || parser.parseRParen())
604         return failure();
605       clauseSegments[pos[numThreadsClause]] = 1;
606     } else if (clauseKeyword == "private") {
607       if (checkAllowed(privateClause) ||
608           parseOperandAndTypeList(parser, privates, privateTypes))
609         return failure();
610       clauseSegments[pos[privateClause]] = privates.size();
611     } else if (clauseKeyword == "firstprivate") {
612       if (checkAllowed(firstprivateClause) ||
613           parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
614         return failure();
615       clauseSegments[pos[firstprivateClause]] = firstprivates.size();
616     } else if (clauseKeyword == "lastprivate") {
617       if (checkAllowed(lastprivateClause) ||
618           parseOperandAndTypeList(parser, lastprivates, lastprivateTypes))
619         return failure();
620       clauseSegments[pos[lastprivateClause]] = lastprivates.size();
621     } else if (clauseKeyword == "shared") {
622       if (checkAllowed(sharedClause) ||
623           parseOperandAndTypeList(parser, shareds, sharedTypes))
624         return failure();
625       clauseSegments[pos[sharedClause]] = shareds.size();
626     } else if (clauseKeyword == "copyin") {
627       if (checkAllowed(copyinClause) ||
628           parseOperandAndTypeList(parser, copyins, copyinTypes))
629         return failure();
630       clauseSegments[pos[copyinClause]] = copyins.size();
631     } else if (clauseKeyword == "allocate") {
632       if (checkAllowed(allocateClause) ||
633           parseAllocateAndAllocator(parser, allocates, allocateTypes,
634                                     allocators, allocatorTypes))
635         return failure();
636       clauseSegments[pos[allocateClause]] = allocates.size();
637       clauseSegments[pos[allocateClause] + 1] = allocators.size();
638     } else if (clauseKeyword == "default") {
639       StringRef defval;
640       if (checkAllowed(defaultClause) || parser.parseLParen() ||
641           parser.parseKeyword(&defval) || parser.parseRParen())
642         return failure();
643       // The def prefix is required for the attribute as "private" is a keyword
644       // in C++.
645       auto attr = parser.getBuilder().getStringAttr("def" + defval);
646       result.addAttribute("default_val", attr);
647     } else if (clauseKeyword == "proc_bind") {
648       StringRef bind;
649       if (checkAllowed(procBindClause) || parser.parseLParen() ||
650           parser.parseKeyword(&bind) || parser.parseRParen())
651         return failure();
652       auto attr = parser.getBuilder().getStringAttr(bind);
653       result.addAttribute("proc_bind_val", attr);
654     } else if (clauseKeyword == "reduction") {
655       if (checkAllowed(reductionClause) ||
656           parseReductionVarList(parser, reductionSymbols, reductionVars,
657                                 reductionVarTypes))
658         return failure();
659       clauseSegments[pos[reductionClause]] = reductionVars.size();
660     } else if (clauseKeyword == "nowait") {
661       if (checkAllowed(nowaitClause))
662         return failure();
663       auto attr = UnitAttr::get(parser.getBuilder().getContext());
664       result.addAttribute("nowait", attr);
665     } else if (clauseKeyword == "linear") {
666       if (checkAllowed(linearClause) ||
667           parseLinearClause(parser, linears, linearTypes, linearSteps))
668         return failure();
669       clauseSegments[pos[linearClause]] = linears.size();
670       clauseSegments[pos[linearClause] + 1] = linearSteps.size();
671     } else if (clauseKeyword == "schedule") {
672       if (checkAllowed(scheduleClause) ||
673           parseScheduleClause(parser, schedule, scheduleChunkSize))
674         return failure();
675       if (scheduleChunkSize) {
676         clauseSegments[pos[scheduleClause]] = 1;
677       }
678     } else if (clauseKeyword == "collapse") {
679       auto type = parser.getBuilder().getI64Type();
680       mlir::IntegerAttr attr;
681       if (checkAllowed(collapseClause) || parser.parseLParen() ||
682           parser.parseAttribute(attr, type) || parser.parseRParen())
683         return failure();
684       result.addAttribute("collapse_val", attr);
685     } else if (clauseKeyword == "ordered") {
686       mlir::IntegerAttr attr;
687       if (checkAllowed(orderedClause))
688         return failure();
689       if (succeeded(parser.parseOptionalLParen())) {
690         auto type = parser.getBuilder().getI64Type();
691         if (parser.parseAttribute(attr, type) || parser.parseRParen())
692           return failure();
693       } else {
694         // Use 0 to represent no ordered parameter was specified
695         attr = parser.getBuilder().getI64IntegerAttr(0);
696       }
697       result.addAttribute("ordered_val", attr);
698     } else if (clauseKeyword == "order") {
699       StringRef order;
700       if (checkAllowed(orderClause) || parser.parseLParen() ||
701           parser.parseKeyword(&order) || parser.parseRParen())
702         return failure();
703       auto attr = parser.getBuilder().getStringAttr(order);
704       result.addAttribute("order", attr);
705     } else if (clauseKeyword == "inclusive") {
706       if (checkAllowed(inclusiveClause))
707         return failure();
708       auto attr = UnitAttr::get(parser.getBuilder().getContext());
709       result.addAttribute("inclusive", attr);
710     } else {
711       return parser.emitError(parser.getNameLoc())
712              << clauseKeyword << " is not a valid clause";
713     }
714   }
715 
716   // Add if parameter.
717   if (done[ifClause] && clauseSegments[pos[ifClause]] &&
718       failed(
719           parser.resolveOperand(ifCond.first, ifCond.second, result.operands)))
720     return failure();
721 
722   // Add num_threads parameter.
723   if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] &&
724       failed(parser.resolveOperand(numThreads.first, numThreads.second,
725                                    result.operands)))
726     return failure();
727 
728   // Add private parameters.
729   if (done[privateClause] && clauseSegments[pos[privateClause]] &&
730       failed(parser.resolveOperands(privates, privateTypes,
731                                     privates[0].location, result.operands)))
732     return failure();
733 
734   // Add firstprivate parameters.
735   if (done[firstprivateClause] && clauseSegments[pos[firstprivateClause]] &&
736       failed(parser.resolveOperands(firstprivates, firstprivateTypes,
737                                     firstprivates[0].location,
738                                     result.operands)))
739     return failure();
740 
741   // Add lastprivate parameters.
742   if (done[lastprivateClause] && clauseSegments[pos[lastprivateClause]] &&
743       failed(parser.resolveOperands(lastprivates, lastprivateTypes,
744                                     lastprivates[0].location, result.operands)))
745     return failure();
746 
747   // Add shared parameters.
748   if (done[sharedClause] && clauseSegments[pos[sharedClause]] &&
749       failed(parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
750                                     result.operands)))
751     return failure();
752 
753   // Add copyin parameters.
754   if (done[copyinClause] && clauseSegments[pos[copyinClause]] &&
755       failed(parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
756                                     result.operands)))
757     return failure();
758 
759   // Add allocate parameters.
760   if (done[allocateClause] && clauseSegments[pos[allocateClause]] &&
761       failed(parser.resolveOperands(allocates, allocateTypes,
762                                     allocates[0].location, result.operands)))
763     return failure();
764 
765   // Add allocator parameters.
766   if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] &&
767       failed(parser.resolveOperands(allocators, allocatorTypes,
768                                     allocators[0].location, result.operands)))
769     return failure();
770 
771   // Add reduction parameters and symbols
772   if (done[reductionClause] && clauseSegments[pos[reductionClause]]) {
773     if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
774                                       parser.getNameLoc(), result.operands)))
775       return failure();
776 
777     SmallVector<Attribute> reductions(reductionSymbols.begin(),
778                                       reductionSymbols.end());
779     result.addAttribute("reductions",
780                         parser.getBuilder().getArrayAttr(reductions));
781   }
782 
783   // Add linear parameters
784   if (done[linearClause] && clauseSegments[pos[linearClause]]) {
785     auto linearStepType = parser.getBuilder().getI32Type();
786     SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
787     if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location,
788                                       result.operands)) ||
789         failed(parser.resolveOperands(linearSteps, linearStepTypes,
790                                       linearSteps[0].location,
791                                       result.operands)))
792       return failure();
793   }
794 
795   // Add schedule parameters
796   if (done[scheduleClause] && !schedule.empty()) {
797     schedule[0] = llvm::toUpper(schedule[0]);
798     auto attr = parser.getBuilder().getStringAttr(schedule);
799     result.addAttribute("schedule_val", attr);
800     if (scheduleChunkSize) {
801       auto chunkSizeType = parser.getBuilder().getI32Type();
802       parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands);
803     }
804   }
805 
806   segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end());
807 
808   return success();
809 }
810 
811 /// Parses a parallel operation.
812 ///
813 /// operation ::= `omp.parallel` clause-list
814 /// clause-list ::= clause | clause clause-list
815 /// clause ::= if | num-threads | private | firstprivate | shared | copyin |
816 ///            allocate | default | proc-bind
817 ///
818 static ParseResult parseParallelOp(OpAsmParser &parser,
819                                    OperationState &result) {
820   SmallVector<ClauseType> clauses = {
821       ifClause,           numThreadsClause, privateClause,
822       firstprivateClause, sharedClause,     copyinClause,
823       allocateClause,     defaultClause,    procBindClause};
824 
825   SmallVector<int> segments;
826 
827   if (failed(parseClauses(parser, result, clauses, segments)))
828     return failure();
829 
830   result.addAttribute("operand_segment_sizes",
831                       parser.getBuilder().getI32VectorAttr(segments));
832 
833   Region *body = result.addRegion();
834   SmallVector<OpAsmParser::OperandType> regionArgs;
835   SmallVector<Type> regionArgTypes;
836   if (parser.parseRegion(*body, regionArgs, regionArgTypes))
837     return failure();
838   return success();
839 }
840 
841 /// Parses an OpenMP Workshare Loop operation
842 ///
843 /// wsloop ::= `omp.wsloop` loop-control clause-list
844 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
845 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps
846 /// steps := `step` `(`ssa-id-list`)`
847 /// clause-list ::= clause clause-list | empty
848 /// clause ::= private | firstprivate | lastprivate | linear | schedule |
849 //             collapse | nowait | ordered | order | inclusive | reduction
850 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
851 
852   // Parse an opening `(` followed by induction variables followed by `)`
853   SmallVector<OpAsmParser::OperandType> ivs;
854   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
855                                      OpAsmParser::Delimiter::Paren))
856     return failure();
857 
858   int numIVs = static_cast<int>(ivs.size());
859   Type loopVarType;
860   if (parser.parseColonType(loopVarType))
861     return failure();
862 
863   // Parse loop bounds.
864   SmallVector<OpAsmParser::OperandType> lower;
865   if (parser.parseEqual() ||
866       parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
867       parser.resolveOperands(lower, loopVarType, result.operands))
868     return failure();
869 
870   SmallVector<OpAsmParser::OperandType> upper;
871   if (parser.parseKeyword("to") ||
872       parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
873       parser.resolveOperands(upper, loopVarType, result.operands))
874     return failure();
875 
876   // Parse step values.
877   SmallVector<OpAsmParser::OperandType> steps;
878   if (parser.parseKeyword("step") ||
879       parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
880       parser.resolveOperands(steps, loopVarType, result.operands))
881     return failure();
882 
883   SmallVector<ClauseType> clauses = {
884       privateClause,   firstprivateClause, lastprivateClause, linearClause,
885       reductionClause, collapseClause,     orderClause,       orderedClause,
886       nowaitClause,    scheduleClause};
887   SmallVector<int> segments{numIVs, numIVs, numIVs};
888   if (failed(parseClauses(parser, result, clauses, segments)))
889     return failure();
890 
891   result.addAttribute("operand_segment_sizes",
892                       parser.getBuilder().getI32VectorAttr(segments));
893 
894   // Now parse the body.
895   Region *body = result.addRegion();
896   SmallVector<Type> ivTypes(numIVs, loopVarType);
897   SmallVector<OpAsmParser::OperandType> blockArgs(ivs);
898   if (parser.parseRegion(*body, blockArgs, ivTypes))
899     return failure();
900   return success();
901 }
902 
903 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
904   auto args = op.getRegion().front().getArguments();
905   p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound()
906     << ") to (" << op.upperBound() << ") step (" << op.step() << ") ";
907 
908   printDataVars(p, op.private_vars(), "private");
909   printDataVars(p, op.firstprivate_vars(), "firstprivate");
910   printDataVars(p, op.lastprivate_vars(), "lastprivate");
911 
912   if (op.linear_vars().size()) {
913     p << "linear";
914     printLinearClause(p, op.linear_vars(), op.linear_step_vars());
915   }
916 
917   if (auto sched = op.schedule_val()) {
918     p << "schedule";
919     printScheduleClause(p, sched.getValue(), op.schedule_chunk_var());
920   }
921 
922   if (auto collapse = op.collapse_val())
923     p << "collapse(" << collapse << ") ";
924 
925   if (op.nowait())
926     p << "nowait ";
927 
928   if (auto ordered = op.ordered_val())
929     p << "ordered(" << ordered << ") ";
930 
931   if (!op.reduction_vars().empty()) {
932     p << "reduction(";
933     printReductionVarList(p, op.reductions(), op.reduction_vars());
934   }
935 
936   if (op.inclusive()) {
937     p << "inclusive ";
938   }
939 
940   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
941 }
942 
943 //===----------------------------------------------------------------------===//
944 // ReductionOp
945 //===----------------------------------------------------------------------===//
946 
947 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
948                                               Region &region) {
949   if (parser.parseOptionalKeyword("atomic"))
950     return success();
951   return parser.parseRegion(region);
952 }
953 
954 static void printAtomicReductionRegion(OpAsmPrinter &printer,
955                                        ReductionDeclareOp op, Region &region) {
956   if (region.empty())
957     return;
958   printer << "atomic ";
959   printer.printRegion(region);
960 }
961 
962 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) {
963   if (op.initializerRegion().empty())
964     return op.emitOpError() << "expects non-empty initializer region";
965   Block &initializerEntryBlock = op.initializerRegion().front();
966   if (initializerEntryBlock.getNumArguments() != 1 ||
967       initializerEntryBlock.getArgument(0).getType() != op.type()) {
968     return op.emitOpError() << "expects initializer region with one argument "
969                                "of the reduction type";
970   }
971 
972   for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) {
973     if (yieldOp.results().size() != 1 ||
974         yieldOp.results().getTypes()[0] != op.type())
975       return op.emitOpError() << "expects initializer region to yield a value "
976                                  "of the reduction type";
977   }
978 
979   if (op.reductionRegion().empty())
980     return op.emitOpError() << "expects non-empty reduction region";
981   Block &reductionEntryBlock = op.reductionRegion().front();
982   if (reductionEntryBlock.getNumArguments() != 2 ||
983       reductionEntryBlock.getArgumentTypes()[0] !=
984           reductionEntryBlock.getArgumentTypes()[1] ||
985       reductionEntryBlock.getArgumentTypes()[0] != op.type())
986     return op.emitOpError() << "expects reduction region with two arguments of "
987                                "the reduction type";
988   for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) {
989     if (yieldOp.results().size() != 1 ||
990         yieldOp.results().getTypes()[0] != op.type())
991       return op.emitOpError() << "expects reduction region to yield a value "
992                                  "of the reduction type";
993   }
994 
995   if (op.atomicReductionRegion().empty())
996     return success();
997 
998   Block &atomicReductionEntryBlock = op.atomicReductionRegion().front();
999   if (atomicReductionEntryBlock.getNumArguments() != 2 ||
1000       atomicReductionEntryBlock.getArgumentTypes()[0] !=
1001           atomicReductionEntryBlock.getArgumentTypes()[1])
1002     return op.emitOpError() << "expects atomic reduction region with two "
1003                                "arguments of the same type";
1004   auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
1005                      .dyn_cast<PointerLikeType>();
1006   if (!ptrType || ptrType.getElementType() != op.type())
1007     return op.emitOpError() << "expects atomic reduction region arguments to "
1008                                "be accumulators containing the reduction type";
1009   return success();
1010 }
1011 
1012 static LogicalResult verifyReductionOp(ReductionOp op) {
1013   // TODO: generalize this to an op interface when there is more than one op
1014   // that supports reductions.
1015   auto container = op->getParentOfType<WsLoopOp>();
1016   for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
1017     if (container.reduction_vars()[i] == op.accumulator())
1018       return success();
1019 
1020   return op.emitOpError() << "the accumulator is not used by the parent";
1021 }
1022 
1023 //===----------------------------------------------------------------------===//
1024 // WsLoopOp
1025 //===----------------------------------------------------------------------===//
1026 
/// Convenience builder taking only the loop bounds and steps.
///
/// Forwards to the full builder with empty value ranges for every clause
/// operand, null values for every clause attribute and for the schedule chunk
/// operand, and `buildBody` set to false; `attributes` are then appended to
/// `state`.
void WsLoopOp::build(OpBuilder &builder, OperationState &state,
                     ValueRange lowerBound, ValueRange upperBound,
                     ValueRange step, ArrayRef<NamedAttribute> attributes) {
  build(builder, state, TypeRange(), lowerBound, upperBound, step,
        /*private_vars=*/ValueRange(),
        /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
        /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
        /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr,
        /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr,
        /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr,
        /*inclusive=*/nullptr, /*buildBody=*/false);
  // Caller-supplied attributes are added after the ones set by the full
  // builder.
  state.addAttributes(attributes);
}
1040 
1041 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes,
1042                      ValueRange operands, ArrayRef<NamedAttribute> attributes) {
1043   state.addOperands(operands);
1044   state.addAttributes(attributes);
1045   (void)state.addRegion();
1046   assert(resultTypes.empty() && "mismatched number of return types");
1047   state.addTypes(resultTypes);
1048 }
1049 
1050 void WsLoopOp::build(OpBuilder &builder, OperationState &result,
1051                      TypeRange typeRange, ValueRange lowerBounds,
1052                      ValueRange upperBounds, ValueRange steps,
1053                      ValueRange privateVars, ValueRange firstprivateVars,
1054                      ValueRange lastprivateVars, ValueRange linearVars,
1055                      ValueRange linearStepVars, ValueRange reductionVars,
1056                      StringAttr scheduleVal, Value scheduleChunkVar,
1057                      IntegerAttr collapseVal, UnitAttr nowait,
1058                      IntegerAttr orderedVal, StringAttr orderVal,
1059                      UnitAttr inclusive, bool buildBody) {
1060   result.addOperands(lowerBounds);
1061   result.addOperands(upperBounds);
1062   result.addOperands(steps);
1063   result.addOperands(privateVars);
1064   result.addOperands(firstprivateVars);
1065   result.addOperands(linearVars);
1066   result.addOperands(linearStepVars);
1067   if (scheduleChunkVar)
1068     result.addOperands(scheduleChunkVar);
1069 
1070   if (scheduleVal)
1071     result.addAttribute("schedule_val", scheduleVal);
1072   if (collapseVal)
1073     result.addAttribute("collapse_val", collapseVal);
1074   if (nowait)
1075     result.addAttribute("nowait", nowait);
1076   if (orderedVal)
1077     result.addAttribute("ordered_val", orderedVal);
1078   if (orderVal)
1079     result.addAttribute("order", orderVal);
1080   if (inclusive)
1081     result.addAttribute("inclusive", inclusive);
1082   result.addAttribute(
1083       WsLoopOp::getOperandSegmentSizeAttr(),
1084       builder.getI32VectorAttr(
1085           {static_cast<int32_t>(lowerBounds.size()),
1086            static_cast<int32_t>(upperBounds.size()),
1087            static_cast<int32_t>(steps.size()),
1088            static_cast<int32_t>(privateVars.size()),
1089            static_cast<int32_t>(firstprivateVars.size()),
1090            static_cast<int32_t>(lastprivateVars.size()),
1091            static_cast<int32_t>(linearVars.size()),
1092            static_cast<int32_t>(linearStepVars.size()),
1093            static_cast<int32_t>(reductionVars.size()),
1094            static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)}));
1095 
1096   Region *bodyRegion = result.addRegion();
1097   if (buildBody) {
1098     OpBuilder::InsertionGuard guard(builder);
1099     unsigned numIVs = steps.size();
1100     SmallVector<Type, 8> argTypes(numIVs, steps.getType().front());
1101     builder.createBlock(bodyRegion, {}, argTypes);
1102   }
1103 }
1104 
1105 static LogicalResult verifyWsLoopOp(WsLoopOp op) {
1106   return verifyReductionVarList(op, op.reductions(), op.reduction_vars());
1107 }
1108 
1109 //===----------------------------------------------------------------------===//
1110 // Verifier for critical construct (2.17.1)
1111 //===----------------------------------------------------------------------===//
1112 
1113 static LogicalResult verifyCriticalDeclareOp(CriticalDeclareOp op) {
1114   return verifySynchronizationHint(op, op.hint());
1115 }
1116 
1117 static LogicalResult verifyCriticalOp(CriticalOp op) {
1118 
1119   if (op.nameAttr()) {
1120     auto symbolRef = op.nameAttr().cast<SymbolRefAttr>();
1121     auto decl =
1122         SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef);
1123     if (!decl) {
1124       return op.emitOpError() << "expected symbol reference " << symbolRef
1125                               << " to point to a critical declaration";
1126     }
1127   }
1128 
1129   return success();
1130 }
1131 
1132 //===----------------------------------------------------------------------===//
1133 // Verifier for ordered construct
1134 //===----------------------------------------------------------------------===//
1135 
1136 static LogicalResult verifyOrderedOp(OrderedOp op) {
1137   auto container = op->getParentOfType<WsLoopOp>();
1138   if (!container || !container.ordered_valAttr() ||
1139       container.ordered_valAttr().getInt() == 0)
1140     return op.emitOpError() << "ordered depend directive must be closely "
1141                             << "nested inside a worksharing-loop with ordered "
1142                             << "clause with parameter present";
1143 
1144   if (container.ordered_valAttr().getInt() !=
1145       (int64_t)op.num_loops_val().getValue())
1146     return op.emitOpError() << "number of variables in depend clause does not "
1147                             << "match number of iteration variables in the "
1148                             << "doacross loop";
1149 
1150   return success();
1151 }
1152 
1153 static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) {
1154   // TODO: The code generation for ordered simd directive is not supported yet.
1155   if (op.simd())
1156     return failure();
1157 
1158   if (auto container = op->getParentOfType<WsLoopOp>()) {
1159     if (!container.ordered_valAttr() ||
1160         container.ordered_valAttr().getInt() != 0)
1161       return op.emitOpError() << "ordered region must be closely nested inside "
1162                               << "a worksharing-loop region with an ordered "
1163                               << "clause without parameter present";
1164   }
1165 
1166   return success();
1167 }
1168 
1169 #define GET_OP_CLASSES
1170 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
1171