1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/OperationSupport.h"
19 
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/SmallString.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/ADT/StringSwitch.h"
25 #include <cstddef>
26 
27 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
28 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
29 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
30 
31 using namespace mlir;
32 using namespace mlir::omp;
33 
34 namespace {
35 /// Model for pointer-like types that already provide a `getElementType` method.
36 template <typename T>
37 struct PointerLikeModel
38     : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
39   Type getElementType(Type pointer) const {
40     return pointer.cast<T>().getElementType();
41   }
42 };
43 } // end namespace
44 
/// Registers the OpenMP dialect's operations and attaches the
/// `PointerLikeType` interface (through `PointerLikeModel`) to the LLVM
/// dialect pointer type and to `MemRefType`.
void OpenMPDialect::initialize() {
  // The operation list is generated by TableGen into OpenMPOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();

  // Allow LLVM pointers and memrefs wherever a PointerLikeType is required
  // (for example, reduction accumulators in verifyReductionVarList).
  LLVM::LLVMPointerType::attachInterface<
      PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
  MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
}
55 
56 //===----------------------------------------------------------------------===//
57 // ParallelOp
58 //===----------------------------------------------------------------------===//
59 
60 void ParallelOp::build(OpBuilder &builder, OperationState &state,
61                        ArrayRef<NamedAttribute> attributes) {
62   ParallelOp::build(
63       builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
64       /*default_val=*/nullptr, /*private_vars=*/ValueRange(),
65       /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(),
66       /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(),
67       /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr);
68   state.addAttributes(attributes);
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // Parser and printer for Operand and type list
73 //===----------------------------------------------------------------------===//
74 
75 /// Parse a list of operands with types.
76 ///
77 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
78 /// ssa-id-and-type-list ::= ssa-id-and-type |
79 ///                          ssa-id-and-type `,` ssa-id-and-type-list
80 /// ssa-id-and-type ::= ssa-id `:` type
81 static ParseResult
82 parseOperandAndTypeList(OpAsmParser &parser,
83                         SmallVectorImpl<OpAsmParser::OperandType> &operands,
84                         SmallVectorImpl<Type> &types) {
85   return parser.parseCommaSeparatedList(
86       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
87         OpAsmParser::OperandType operand;
88         Type type;
89         if (parser.parseOperand(operand) || parser.parseColonType(type))
90           return failure();
91         operands.push_back(operand);
92         types.push_back(type);
93         return success();
94       });
95 }
96 
97 /// Print an operand and type list with parentheses
98 static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) {
99   p << "(";
100   llvm::interleaveComma(
101       operands, p, [&](const Value &v) { p << v << " : " << v.getType(); });
102   p << ") ";
103 }
104 
105 /// Print data variables corresponding to a data-sharing clause `name`
106 static void printDataVars(OpAsmPrinter &p, OperandRange operands,
107                           StringRef name) {
108   if (operands.size()) {
109     p << name;
110     printOperandAndTypeList(p, operands);
111   }
112 }
113 
114 //===----------------------------------------------------------------------===//
115 // Parser and printer for Allocate Clause
116 //===----------------------------------------------------------------------===//
117 
118 /// Parse an allocate clause with allocators and a list of operands with types.
119 ///
120 /// allocate ::= `allocate` `(` allocate-operand-list `)`
121 /// allocate-operand-list :: = allocate-operand |
122 ///                            allocator-operand `,` allocate-operand-list
123 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
124 /// ssa-id-and-type ::= ssa-id `:` type
125 static ParseResult parseAllocateAndAllocator(
126     OpAsmParser &parser,
127     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
128     SmallVectorImpl<Type> &typesAllocate,
129     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
130     SmallVectorImpl<Type> &typesAllocator) {
131 
132   return parser.parseCommaSeparatedList(
133       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
134         OpAsmParser::OperandType operand;
135         Type type;
136         if (parser.parseOperand(operand) || parser.parseColonType(type))
137           return failure();
138         operandsAllocator.push_back(operand);
139         typesAllocator.push_back(type);
140         if (parser.parseArrow())
141           return failure();
142         if (parser.parseOperand(operand) || parser.parseColonType(type))
143           return failure();
144 
145         operandsAllocate.push_back(operand);
146         typesAllocate.push_back(type);
147         return success();
148       });
149 }
150 
151 /// Print allocate clause
152 static void printAllocateAndAllocator(OpAsmPrinter &p,
153                                       OperandRange varsAllocate,
154                                       OperandRange varsAllocator) {
155   p << "allocate(";
156   for (unsigned i = 0; i < varsAllocate.size(); ++i) {
157     std::string separator = i == varsAllocate.size() - 1 ? ") " : ", ";
158     p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
159     p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
160   }
161 }
162 
163 static LogicalResult verifyParallelOp(ParallelOp op) {
164   if (op.allocate_vars().size() != op.allocators_vars().size())
165     return op.emitError(
166         "expected equal sizes for allocate and allocator variables");
167   return success();
168 }
169 
170 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
171   p << " ";
172   if (auto ifCond = op.if_expr_var())
173     p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
174 
175   if (auto threads = op.num_threads_var())
176     p << "num_threads(" << threads << " : " << threads.getType() << ") ";
177 
178   printDataVars(p, op.private_vars(), "private");
179   printDataVars(p, op.firstprivate_vars(), "firstprivate");
180   printDataVars(p, op.shared_vars(), "shared");
181   printDataVars(p, op.copyin_vars(), "copyin");
182 
183   if (!op.allocate_vars().empty())
184     printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars());
185 
186   if (auto def = op.default_val())
187     p << "default(" << def->drop_front(3) << ") ";
188 
189   if (auto bind = op.proc_bind_val())
190     p << "proc_bind(" << bind << ") ";
191 
192   p.printRegion(op.getRegion());
193 }
194 
195 //===----------------------------------------------------------------------===//
196 // Parser and printer for Linear Clause
197 //===----------------------------------------------------------------------===//
198 
199 /// linear ::= `linear` `(` linear-list `)`
200 /// linear-list := linear-val | linear-val linear-list
201 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
202 static ParseResult
203 parseLinearClause(OpAsmParser &parser,
204                   SmallVectorImpl<OpAsmParser::OperandType> &vars,
205                   SmallVectorImpl<Type> &types,
206                   SmallVectorImpl<OpAsmParser::OperandType> &stepVars) {
207   if (parser.parseLParen())
208     return failure();
209 
210   do {
211     OpAsmParser::OperandType var;
212     Type type;
213     OpAsmParser::OperandType stepVar;
214     if (parser.parseOperand(var) || parser.parseEqual() ||
215         parser.parseOperand(stepVar) || parser.parseColonType(type))
216       return failure();
217 
218     vars.push_back(var);
219     types.push_back(type);
220     stepVars.push_back(stepVar);
221   } while (succeeded(parser.parseOptionalComma()));
222 
223   if (parser.parseRParen())
224     return failure();
225 
226   return success();
227 }
228 
229 /// Print Linear Clause
230 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars,
231                               OperandRange linearStepVars) {
232   size_t linearVarsSize = linearVars.size();
233   p << "linear(";
234   for (unsigned i = 0; i < linearVarsSize; ++i) {
235     std::string separator = i == linearVarsSize - 1 ? ") " : ", ";
236     p << linearVars[i];
237     if (linearStepVars.size() > i)
238       p << " = " << linearStepVars[i];
239     p << " : " << linearVars[i].getType() << separator;
240   }
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // Parser and printer for Schedule Clause
245 //===----------------------------------------------------------------------===//
246 
247 /// schedule ::= `schedule` `(` sched-list `)`
248 /// sched-list ::= sched-val | sched-val sched-list
249 /// sched-val ::= sched-with-chunk | sched-wo-chunk
250 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
251 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
252 /// sched-wo-chunk ::=  `auto` | `runtime`
253 static ParseResult
254 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
255                     SmallVectorImpl<SmallString<12>> &modifiers,
256                     Optional<OpAsmParser::OperandType> &chunkSize) {
257   if (parser.parseLParen())
258     return failure();
259 
260   StringRef keyword;
261   if (parser.parseKeyword(&keyword))
262     return failure();
263 
264   schedule = keyword;
265   if (keyword == "static" || keyword == "dynamic" || keyword == "guided") {
266     if (succeeded(parser.parseOptionalEqual())) {
267       chunkSize = OpAsmParser::OperandType{};
268       if (parser.parseOperand(*chunkSize))
269         return failure();
270     } else {
271       chunkSize = llvm::NoneType::None;
272     }
273   } else if (keyword == "auto" || keyword == "runtime") {
274     chunkSize = llvm::NoneType::None;
275   } else {
276     return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
277   }
278 
279   // If there is a comma, we have one or more modifiers..
280   if (succeeded(parser.parseOptionalComma())) {
281     StringRef mod;
282     if (parser.parseKeyword(&mod))
283       return failure();
284     modifiers.push_back(mod);
285   }
286 
287   if (parser.parseRParen())
288     return failure();
289 
290   return success();
291 }
292 
293 /// Print schedule clause
294 static void printScheduleClause(OpAsmPrinter &p, StringRef &sched,
295                                 llvm::Optional<StringRef> modifier,
296                                 Value scheduleChunkVar) {
297   std::string schedLower = sched.lower();
298   p << "schedule(" << schedLower;
299   if (scheduleChunkVar)
300     p << " = " << scheduleChunkVar;
301   if (modifier && modifier.getValue() != "none")
302     p << ", " << modifier;
303   p << ") ";
304 }
305 
306 //===----------------------------------------------------------------------===//
307 // Parser, printer and verifier for ReductionVarList
308 //===----------------------------------------------------------------------===//
309 
310 /// reduction ::= `reduction` `(` reduction-entry-list `)`
311 /// reduction-entry-list ::= reduction-entry
312 ///                        | reduction-entry-list `,` reduction-entry
313 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
314 static ParseResult
315 parseReductionVarList(OpAsmParser &parser,
316                       SmallVectorImpl<SymbolRefAttr> &symbols,
317                       SmallVectorImpl<OpAsmParser::OperandType> &operands,
318                       SmallVectorImpl<Type> &types) {
319   if (failed(parser.parseLParen()))
320     return failure();
321 
322   do {
323     if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() ||
324         parser.parseOperand(operands.emplace_back()) ||
325         parser.parseColonType(types.emplace_back()))
326       return failure();
327   } while (succeeded(parser.parseOptionalComma()));
328   return parser.parseRParen();
329 }
330 
331 /// Print Reduction clause
332 static void printReductionVarList(OpAsmPrinter &p,
333                                   Optional<ArrayAttr> reductions,
334                                   OperandRange reduction_vars) {
335   p << "reduction(";
336   for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
337     if (i != 0)
338       p << ", ";
339     p << (*reductions)[i] << " -> " << reduction_vars[i] << " : "
340       << reduction_vars[i].getType();
341   }
342   p << ") ";
343 }
344 
345 /// Verifies Reduction Clause
346 static LogicalResult verifyReductionVarList(Operation *op,
347                                             Optional<ArrayAttr> reductions,
348                                             OperandRange reduction_vars) {
349   if (reduction_vars.size() != 0) {
350     if (!reductions || reductions->size() != reduction_vars.size())
351       return op->emitOpError()
352              << "expected as many reduction symbol references "
353                 "as reduction variables";
354   } else {
355     if (reductions)
356       return op->emitOpError() << "unexpected reduction symbol references";
357     return success();
358   }
359 
360   DenseSet<Value> accumulators;
361   for (auto args : llvm::zip(reduction_vars, *reductions)) {
362     Value accum = std::get<0>(args);
363 
364     if (!accumulators.insert(accum).second)
365       return op->emitOpError() << "accumulator variable used more than once";
366 
367     Type varType = accum.getType().cast<PointerLikeType>();
368     auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
369     auto decl =
370         SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
371     if (!decl)
372       return op->emitOpError() << "expected symbol reference " << symbolRef
373                                << " to point to a reduction declaration";
374 
375     if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
376       return op->emitOpError()
377              << "expected accumulator (" << varType
378              << ") to be the same type as reduction declaration ("
379              << decl.getAccumulatorType() << ")";
380   }
381 
382   return success();
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // Parser, printer and verifier for Synchronization Hint (2.17.12)
387 //===----------------------------------------------------------------------===//
388 
389 /// Parses a Synchronization Hint clause. The value of hint is an integer
390 /// which is a combination of different hints from `omp_sync_hint_t`.
391 ///
392 /// hint-clause = `hint` `(` hint-value `)`
393 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
394                                             IntegerAttr &hintAttr,
395                                             bool parseKeyword = true) {
396   if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) {
397     hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
398     return success();
399   }
400 
401   if (failed(parser.parseLParen()))
402     return failure();
403   StringRef hintKeyword;
404   int64_t hint = 0;
405   do {
406     if (failed(parser.parseKeyword(&hintKeyword)))
407       return failure();
408     if (hintKeyword == "uncontended")
409       hint |= 1;
410     else if (hintKeyword == "contended")
411       hint |= 2;
412     else if (hintKeyword == "nonspeculative")
413       hint |= 4;
414     else if (hintKeyword == "speculative")
415       hint |= 8;
416     else
417       return parser.emitError(parser.getCurrentLocation())
418              << hintKeyword << " is not a valid hint";
419   } while (succeeded(parser.parseOptionalComma()));
420   if (failed(parser.parseRParen()))
421     return failure();
422   hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
423   return success();
424 }
425 
426 /// Prints a Synchronization Hint clause
427 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
428                                      IntegerAttr hintAttr) {
429   int64_t hint = hintAttr.getInt();
430 
431   if (hint == 0)
432     return;
433 
434   // Helper function to get n-th bit from the right end of `value`
435   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
436 
437   bool uncontended = bitn(hint, 0);
438   bool contended = bitn(hint, 1);
439   bool nonspeculative = bitn(hint, 2);
440   bool speculative = bitn(hint, 3);
441 
442   SmallVector<StringRef> hints;
443   if (uncontended)
444     hints.push_back("uncontended");
445   if (contended)
446     hints.push_back("contended");
447   if (nonspeculative)
448     hints.push_back("nonspeculative");
449   if (speculative)
450     hints.push_back("speculative");
451 
452   p << "hint(";
453   llvm::interleaveComma(hints, p);
454   p << ") ";
455 }
456 
457 /// Verifies a synchronization hint clause
458 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
459 
460   // Helper function to get n-th bit from the right end of `value`
461   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
462 
463   bool uncontended = bitn(hint, 0);
464   bool contended = bitn(hint, 1);
465   bool nonspeculative = bitn(hint, 2);
466   bool speculative = bitn(hint, 3);
467 
468   if (uncontended && contended)
469     return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
470                                 "omp_sync_hint_contended cannot be combined";
471   if (nonspeculative && speculative)
472     return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
473                                 "omp_sync_hint_speculative cannot be combined.";
474   return success();
475 }
476 
/// Identifies each clause kind that parseClauses recognizes. An operation
/// passes the subset of clauses it accepts to parseClauses; keywords for
/// clauses outside that subset are rejected.
///
/// The enumerators also index per-clause tables sized by `COUNT` (e.g. the
/// `done` bit vector and `pos` array in parseClauses), so `COUNT` must stay
/// the last enumerator.
enum ClauseType {
  ifClause,
  numThreadsClause,
  privateClause,
  firstprivateClause,
  lastprivateClause,
  sharedClause,
  copyinClause,
  allocateClause,
  defaultClause,
  procBindClause,
  reductionClause,
  nowaitClause,
  linearClause,
  scheduleClause,
  collapseClause,
  orderClause,
  orderedClause,
  memoryOrderClause,
  hintClause,
  // Number of clause kinds; not a real clause.
  COUNT
};
499 
500 //===----------------------------------------------------------------------===//
501 // Parser for Clause List
502 //===----------------------------------------------------------------------===//
503 
/// Parse a list of clauses. The clauses can appear in any order, but their
/// operand segment indices are in the same order that they are passed in the
/// `clauses` list. The operand segment sizes of the parsed clauses are added
/// to the previous segments held in `segments`.
///
/// clause-list ::= clause clause-list | empty
/// clause ::= if | num-threads | private | firstprivate | lastprivate |
///            shared | copyin | allocate | default | proc-bind | reduction |
///            nowait | linear | schedule | collapse | order | ordered |
///            memory-order | hint
/// if ::= `if` `(` ssa-id-and-type `)`
/// num-threads ::= `num_threads` `(` ssa-id-and-type `)`
/// private ::= `private` operand-and-type-list
/// firstprivate ::= `firstprivate` operand-and-type-list
/// lastprivate ::= `lastprivate` operand-and-type-list
/// shared ::= `shared` operand-and-type-list
/// copyin ::= `copyin` operand-and-type-list
/// allocate ::= `allocate` `(` allocate-operand-list `)`
/// default ::= `default` `(`
///             (`private` | `firstprivate` | `shared` | `none`) `)`
/// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
/// reduction ::= `reduction` `(` reduction-entry-list `)`
/// nowait ::= `nowait`
/// linear ::= `linear` `(` linear-list `)`
/// schedule ::= `schedule` `(` sched-list `)`
/// collapse ::= `collapse` `(` integer-literal `)`
/// order ::= `order` `(` `concurrent` `)`
/// ordered ::= `ordered` (`(` integer-literal `)`)?
/// memory-order ::= `memory_order` `(` bare-id `)`
/// hint ::= `hint` `(` hint-value (`,` hint-value)* `)`
///
/// Note that each clause can only appear once in the clause-list.
533 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
534                                 SmallVectorImpl<ClauseType> &clauses,
535                                 SmallVectorImpl<int> &segments) {
536 
537   // Check done[clause] to see if it has been parsed already
538   llvm::BitVector done(ClauseType::COUNT, false);
539 
540   // See pos[clause] to get position of clause in operand segments
541   SmallVector<int> pos(ClauseType::COUNT, -1);
542 
543   // Stores the last parsed clause keyword
544   StringRef clauseKeyword;
545   StringRef opName = result.name.getStringRef();
546 
547   // Containers for storing operands, types and attributes for various clauses
548   std::pair<OpAsmParser::OperandType, Type> ifCond;
549   std::pair<OpAsmParser::OperandType, Type> numThreads;
550 
551   SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates,
552       shareds, copyins;
553   SmallVector<Type> privateTypes, firstprivateTypes, lastprivateTypes,
554       sharedTypes, copyinTypes;
555 
556   SmallVector<OpAsmParser::OperandType> allocates, allocators;
557   SmallVector<Type> allocateTypes, allocatorTypes;
558 
559   SmallVector<SymbolRefAttr> reductionSymbols;
560   SmallVector<OpAsmParser::OperandType> reductionVars;
561   SmallVector<Type> reductionVarTypes;
562 
563   SmallVector<OpAsmParser::OperandType> linears;
564   SmallVector<Type> linearTypes;
565   SmallVector<OpAsmParser::OperandType> linearSteps;
566 
567   SmallString<8> schedule;
568   SmallVector<SmallString<12>> modifiers;
569   Optional<OpAsmParser::OperandType> scheduleChunkSize;
570 
571   // Compute the position of clauses in operand segments
572   int currPos = 0;
573   for (ClauseType clause : clauses) {
574 
575     // Skip the following clauses - they do not take any position in operand
576     // segments
577     if (clause == defaultClause || clause == procBindClause ||
578         clause == nowaitClause || clause == collapseClause ||
579         clause == orderClause || clause == orderedClause)
580       continue;
581 
582     pos[clause] = currPos++;
583 
584     // For the following clauses, two positions are reserved in the operand
585     // segments
586     if (clause == allocateClause || clause == linearClause)
587       currPos++;
588   }
589 
590   SmallVector<int> clauseSegments(currPos);
591 
592   // Helper function to check if a clause is allowed/repeated or not
593   auto checkAllowed = [&](ClauseType clause,
594                           bool allowRepeat = false) -> ParseResult {
595     if (!llvm::is_contained(clauses, clause))
596       return parser.emitError(parser.getCurrentLocation())
597              << clauseKeyword << " is not a valid clause for the " << opName
598              << " operation";
599     if (done[clause] && !allowRepeat)
600       return parser.emitError(parser.getCurrentLocation())
601              << "at most one " << clauseKeyword << " clause can appear on the "
602              << opName << " operation";
603     done[clause] = true;
604     return success();
605   };
606 
607   while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) {
608     if (clauseKeyword == "if") {
609       if (checkAllowed(ifClause) || parser.parseLParen() ||
610           parser.parseOperand(ifCond.first) ||
611           parser.parseColonType(ifCond.second) || parser.parseRParen())
612         return failure();
613       clauseSegments[pos[ifClause]] = 1;
614     } else if (clauseKeyword == "num_threads") {
615       if (checkAllowed(numThreadsClause) || parser.parseLParen() ||
616           parser.parseOperand(numThreads.first) ||
617           parser.parseColonType(numThreads.second) || parser.parseRParen())
618         return failure();
619       clauseSegments[pos[numThreadsClause]] = 1;
620     } else if (clauseKeyword == "private") {
621       if (checkAllowed(privateClause) ||
622           parseOperandAndTypeList(parser, privates, privateTypes))
623         return failure();
624       clauseSegments[pos[privateClause]] = privates.size();
625     } else if (clauseKeyword == "firstprivate") {
626       if (checkAllowed(firstprivateClause) ||
627           parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
628         return failure();
629       clauseSegments[pos[firstprivateClause]] = firstprivates.size();
630     } else if (clauseKeyword == "lastprivate") {
631       if (checkAllowed(lastprivateClause) ||
632           parseOperandAndTypeList(parser, lastprivates, lastprivateTypes))
633         return failure();
634       clauseSegments[pos[lastprivateClause]] = lastprivates.size();
635     } else if (clauseKeyword == "shared") {
636       if (checkAllowed(sharedClause) ||
637           parseOperandAndTypeList(parser, shareds, sharedTypes))
638         return failure();
639       clauseSegments[pos[sharedClause]] = shareds.size();
640     } else if (clauseKeyword == "copyin") {
641       if (checkAllowed(copyinClause) ||
642           parseOperandAndTypeList(parser, copyins, copyinTypes))
643         return failure();
644       clauseSegments[pos[copyinClause]] = copyins.size();
645     } else if (clauseKeyword == "allocate") {
646       if (checkAllowed(allocateClause) ||
647           parseAllocateAndAllocator(parser, allocates, allocateTypes,
648                                     allocators, allocatorTypes))
649         return failure();
650       clauseSegments[pos[allocateClause]] = allocates.size();
651       clauseSegments[pos[allocateClause] + 1] = allocators.size();
652     } else if (clauseKeyword == "default") {
653       StringRef defval;
654       if (checkAllowed(defaultClause) || parser.parseLParen() ||
655           parser.parseKeyword(&defval) || parser.parseRParen())
656         return failure();
657       // The def prefix is required for the attribute as "private" is a keyword
658       // in C++.
659       auto attr = parser.getBuilder().getStringAttr("def" + defval);
660       result.addAttribute("default_val", attr);
661     } else if (clauseKeyword == "proc_bind") {
662       StringRef bind;
663       if (checkAllowed(procBindClause) || parser.parseLParen() ||
664           parser.parseKeyword(&bind) || parser.parseRParen())
665         return failure();
666       auto attr = parser.getBuilder().getStringAttr(bind);
667       result.addAttribute("proc_bind_val", attr);
668     } else if (clauseKeyword == "reduction") {
669       if (checkAllowed(reductionClause) ||
670           parseReductionVarList(parser, reductionSymbols, reductionVars,
671                                 reductionVarTypes))
672         return failure();
673       clauseSegments[pos[reductionClause]] = reductionVars.size();
674     } else if (clauseKeyword == "nowait") {
675       if (checkAllowed(nowaitClause))
676         return failure();
677       auto attr = UnitAttr::get(parser.getBuilder().getContext());
678       result.addAttribute("nowait", attr);
679     } else if (clauseKeyword == "linear") {
680       if (checkAllowed(linearClause) ||
681           parseLinearClause(parser, linears, linearTypes, linearSteps))
682         return failure();
683       clauseSegments[pos[linearClause]] = linears.size();
684       clauseSegments[pos[linearClause] + 1] = linearSteps.size();
685     } else if (clauseKeyword == "schedule") {
686       if (checkAllowed(scheduleClause) ||
687           parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize))
688         return failure();
689       if (scheduleChunkSize) {
690         clauseSegments[pos[scheduleClause]] = 1;
691       }
692     } else if (clauseKeyword == "collapse") {
693       auto type = parser.getBuilder().getI64Type();
694       mlir::IntegerAttr attr;
695       if (checkAllowed(collapseClause) || parser.parseLParen() ||
696           parser.parseAttribute(attr, type) || parser.parseRParen())
697         return failure();
698       result.addAttribute("collapse_val", attr);
699     } else if (clauseKeyword == "ordered") {
700       mlir::IntegerAttr attr;
701       if (checkAllowed(orderedClause))
702         return failure();
703       if (succeeded(parser.parseOptionalLParen())) {
704         auto type = parser.getBuilder().getI64Type();
705         if (parser.parseAttribute(attr, type) || parser.parseRParen())
706           return failure();
707       } else {
708         // Use 0 to represent no ordered parameter was specified
709         attr = parser.getBuilder().getI64IntegerAttr(0);
710       }
711       result.addAttribute("ordered_val", attr);
712     } else if (clauseKeyword == "order") {
713       StringRef order;
714       if (checkAllowed(orderClause) || parser.parseLParen() ||
715           parser.parseKeyword(&order) || parser.parseRParen())
716         return failure();
717       auto attr = parser.getBuilder().getStringAttr(order);
718       result.addAttribute("order_val", attr);
719     } else if (clauseKeyword == "memory_order") {
720       StringRef memoryOrder;
721       if (checkAllowed(memoryOrderClause) || parser.parseLParen() ||
722           parser.parseKeyword(&memoryOrder) || parser.parseRParen())
723         return failure();
724       result.addAttribute("memory_order",
725                           parser.getBuilder().getStringAttr(memoryOrder));
726     } else if (clauseKeyword == "hint") {
727       IntegerAttr hint;
728       if (checkAllowed(hintClause) ||
729           parseSynchronizationHint(parser, hint, false))
730         return failure();
731       result.addAttribute("hint", hint);
732     } else {
733       return parser.emitError(parser.getNameLoc())
734              << clauseKeyword << " is not a valid clause";
735     }
736   }
737 
738   // Add if parameter.
739   if (done[ifClause] && clauseSegments[pos[ifClause]] &&
740       failed(
741           parser.resolveOperand(ifCond.first, ifCond.second, result.operands)))
742     return failure();
743 
744   // Add num_threads parameter.
745   if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] &&
746       failed(parser.resolveOperand(numThreads.first, numThreads.second,
747                                    result.operands)))
748     return failure();
749 
750   // Add private parameters.
751   if (done[privateClause] && clauseSegments[pos[privateClause]] &&
752       failed(parser.resolveOperands(privates, privateTypes,
753                                     privates[0].location, result.operands)))
754     return failure();
755 
756   // Add firstprivate parameters.
757   if (done[firstprivateClause] && clauseSegments[pos[firstprivateClause]] &&
758       failed(parser.resolveOperands(firstprivates, firstprivateTypes,
759                                     firstprivates[0].location,
760                                     result.operands)))
761     return failure();
762 
763   // Add lastprivate parameters.
764   if (done[lastprivateClause] && clauseSegments[pos[lastprivateClause]] &&
765       failed(parser.resolveOperands(lastprivates, lastprivateTypes,
766                                     lastprivates[0].location, result.operands)))
767     return failure();
768 
769   // Add shared parameters.
770   if (done[sharedClause] && clauseSegments[pos[sharedClause]] &&
771       failed(parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
772                                     result.operands)))
773     return failure();
774 
775   // Add copyin parameters.
776   if (done[copyinClause] && clauseSegments[pos[copyinClause]] &&
777       failed(parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
778                                     result.operands)))
779     return failure();
780 
781   // Add allocate parameters.
782   if (done[allocateClause] && clauseSegments[pos[allocateClause]] &&
783       failed(parser.resolveOperands(allocates, allocateTypes,
784                                     allocates[0].location, result.operands)))
785     return failure();
786 
787   // Add allocator parameters.
788   if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] &&
789       failed(parser.resolveOperands(allocators, allocatorTypes,
790                                     allocators[0].location, result.operands)))
791     return failure();
792 
793   // Add reduction parameters and symbols
794   if (done[reductionClause] && clauseSegments[pos[reductionClause]]) {
795     if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
796                                       parser.getNameLoc(), result.operands)))
797       return failure();
798 
799     SmallVector<Attribute> reductions(reductionSymbols.begin(),
800                                       reductionSymbols.end());
801     result.addAttribute("reductions",
802                         parser.getBuilder().getArrayAttr(reductions));
803   }
804 
805   // Add linear parameters
806   if (done[linearClause] && clauseSegments[pos[linearClause]]) {
807     auto linearStepType = parser.getBuilder().getI32Type();
808     SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
809     if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location,
810                                       result.operands)) ||
811         failed(parser.resolveOperands(linearSteps, linearStepTypes,
812                                       linearSteps[0].location,
813                                       result.operands)))
814       return failure();
815   }
816 
817   // Add schedule parameters
818   if (done[scheduleClause] && !schedule.empty()) {
819     schedule[0] = llvm::toUpper(schedule[0]);
820     auto attr = parser.getBuilder().getStringAttr(schedule);
821     result.addAttribute("schedule_val", attr);
822     if (modifiers.size() > 0) {
823       auto mod = parser.getBuilder().getStringAttr(modifiers[0]);
824       result.addAttribute("schedule_modifier", mod);
825     }
826     if (scheduleChunkSize) {
827       auto chunkSizeType = parser.getBuilder().getI32Type();
828       parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands);
829     }
830   }
831 
832   segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end());
833 
834   return success();
835 }
836 
837 /// Parses a parallel operation.
838 ///
839 /// operation ::= `omp.parallel` clause-list
840 /// clause-list ::= clause | clause clause-list
841 /// clause ::= if | num-threads | private | firstprivate | shared | copyin |
842 ///            allocate | default | proc-bind
843 ///
844 static ParseResult parseParallelOp(OpAsmParser &parser,
845                                    OperationState &result) {
846   SmallVector<ClauseType> clauses = {
847       ifClause,           numThreadsClause, privateClause,
848       firstprivateClause, sharedClause,     copyinClause,
849       allocateClause,     defaultClause,    procBindClause};
850 
851   SmallVector<int> segments;
852 
853   if (failed(parseClauses(parser, result, clauses, segments)))
854     return failure();
855 
856   result.addAttribute("operand_segment_sizes",
857                       parser.getBuilder().getI32VectorAttr(segments));
858 
859   Region *body = result.addRegion();
860   SmallVector<OpAsmParser::OperandType> regionArgs;
861   SmallVector<Type> regionArgTypes;
862   if (parser.parseRegion(*body, regionArgs, regionArgTypes))
863     return failure();
864   return success();
865 }
866 
867 //===----------------------------------------------------------------------===//
868 // Parser, printer and verifier for SectionsOp
869 //===----------------------------------------------------------------------===//
870 
871 /// Parses an OpenMP Sections operation
872 ///
873 /// sections ::= `omp.sections` clause-list
874 /// clause-list ::= clause clause-list | empty
875 /// clause ::= private | firstprivate | lastprivate | reduction | allocate |
876 ///            nowait
877 static ParseResult parseSectionsOp(OpAsmParser &parser,
878                                    OperationState &result) {
879 
880   SmallVector<ClauseType> clauses = {privateClause,     firstprivateClause,
881                                      lastprivateClause, reductionClause,
882                                      allocateClause,    nowaitClause};
883 
884   SmallVector<int> segments;
885 
886   if (failed(parseClauses(parser, result, clauses, segments)))
887     return failure();
888 
889   result.addAttribute("operand_segment_sizes",
890                       parser.getBuilder().getI32VectorAttr(segments));
891 
892   // Now parse the body.
893   Region *body = result.addRegion();
894   if (parser.parseRegion(*body))
895     return failure();
896   return success();
897 }
898 
899 static void printSectionsOp(OpAsmPrinter &p, SectionsOp op) {
900   p << " ";
901   printDataVars(p, op.private_vars(), "private");
902   printDataVars(p, op.firstprivate_vars(), "firstprivate");
903   printDataVars(p, op.lastprivate_vars(), "lastprivate");
904 
905   if (!op.reduction_vars().empty())
906     printReductionVarList(p, op.reductions(), op.reduction_vars());
907 
908   if (!op.allocate_vars().empty())
909     printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars());
910 
911   if (op.nowait())
912     p << "nowait ";
913 
914   p.printRegion(op.region());
915 }
916 
917 static LogicalResult verifySectionsOp(SectionsOp op) {
918 
919   // A list item may not appear in more than one clause on the same directive,
920   // except that it may be specified in both firstprivate and lastprivate
921   // clauses.
922   for (auto var : op.private_vars()) {
923     if (llvm::is_contained(op.firstprivate_vars(), var))
924       return op.emitOpError()
925              << "operand used in both private and firstprivate clauses";
926     if (llvm::is_contained(op.lastprivate_vars(), var))
927       return op.emitOpError()
928              << "operand used in both private and lastprivate clauses";
929   }
930 
931   if (op.allocate_vars().size() != op.allocators_vars().size())
932     return op.emitError(
933         "expected equal sizes for allocate and allocator variables");
934 
935   for (auto &inst : *op.region().begin()) {
936     if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst)))
937       op.emitOpError()
938           << "expected omp.section op or terminator op inside region";
939   }
940 
941   return verifyReductionVarList(op, op.reductions(), op.reduction_vars());
942 }
943 
944 /// Parses an OpenMP Workshare Loop operation
945 ///
946 /// wsloop ::= `omp.wsloop` loop-control clause-list
947 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
948 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
949 /// steps := `step` `(`ssa-id-list`)`
950 /// clause-list ::= clause clause-list | empty
951 /// clause ::= private | firstprivate | lastprivate | linear | schedule |
952 //             collapse | nowait | ordered | order | reduction
953 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
954 
955   // Parse an opening `(` followed by induction variables followed by `)`
956   SmallVector<OpAsmParser::OperandType> ivs;
957   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
958                                      OpAsmParser::Delimiter::Paren))
959     return failure();
960 
961   int numIVs = static_cast<int>(ivs.size());
962   Type loopVarType;
963   if (parser.parseColonType(loopVarType))
964     return failure();
965 
966   // Parse loop bounds.
967   SmallVector<OpAsmParser::OperandType> lower;
968   if (parser.parseEqual() ||
969       parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
970       parser.resolveOperands(lower, loopVarType, result.operands))
971     return failure();
972 
973   SmallVector<OpAsmParser::OperandType> upper;
974   if (parser.parseKeyword("to") ||
975       parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
976       parser.resolveOperands(upper, loopVarType, result.operands))
977     return failure();
978 
979   if (succeeded(parser.parseOptionalKeyword("inclusive"))) {
980     auto attr = UnitAttr::get(parser.getBuilder().getContext());
981     result.addAttribute("inclusive", attr);
982   }
983 
984   // Parse step values.
985   SmallVector<OpAsmParser::OperandType> steps;
986   if (parser.parseKeyword("step") ||
987       parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
988       parser.resolveOperands(steps, loopVarType, result.operands))
989     return failure();
990 
991   SmallVector<ClauseType> clauses = {
992       privateClause,   firstprivateClause, lastprivateClause, linearClause,
993       reductionClause, collapseClause,     orderClause,       orderedClause,
994       nowaitClause,    scheduleClause};
995   SmallVector<int> segments{numIVs, numIVs, numIVs};
996   if (failed(parseClauses(parser, result, clauses, segments)))
997     return failure();
998 
999   result.addAttribute("operand_segment_sizes",
1000                       parser.getBuilder().getI32VectorAttr(segments));
1001 
1002   // Now parse the body.
1003   Region *body = result.addRegion();
1004   SmallVector<Type> ivTypes(numIVs, loopVarType);
1005   SmallVector<OpAsmParser::OperandType> blockArgs(ivs);
1006   if (parser.parseRegion(*body, blockArgs, ivTypes))
1007     return failure();
1008   return success();
1009 }
1010 
1011 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
1012   auto args = op.getRegion().front().getArguments();
1013   p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound()
1014     << ") to (" << op.upperBound() << ") ";
1015   if (op.inclusive()) {
1016     p << "inclusive ";
1017   }
1018   p << "step (" << op.step() << ") ";
1019 
1020   printDataVars(p, op.private_vars(), "private");
1021   printDataVars(p, op.firstprivate_vars(), "firstprivate");
1022   printDataVars(p, op.lastprivate_vars(), "lastprivate");
1023 
1024   if (op.linear_vars().size())
1025     printLinearClause(p, op.linear_vars(), op.linear_step_vars());
1026 
1027   if (auto sched = op.schedule_val())
1028     printScheduleClause(p, sched.getValue(), op.schedule_modifier(),
1029                         op.schedule_chunk_var());
1030 
1031   if (auto collapse = op.collapse_val())
1032     p << "collapse(" << collapse << ") ";
1033 
1034   if (op.nowait())
1035     p << "nowait ";
1036 
1037   if (auto ordered = op.ordered_val())
1038     p << "ordered(" << ordered << ") ";
1039 
1040   if (auto order = op.order_val())
1041     p << "order(" << order << ") ";
1042 
1043   if (!op.reduction_vars().empty())
1044     printReductionVarList(p, op.reductions(), op.reduction_vars());
1045 
1046   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
1047 }
1048 
1049 //===----------------------------------------------------------------------===//
1050 // ReductionOp
1051 //===----------------------------------------------------------------------===//
1052 
1053 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
1054                                               Region &region) {
1055   if (parser.parseOptionalKeyword("atomic"))
1056     return success();
1057   return parser.parseRegion(region);
1058 }
1059 
1060 static void printAtomicReductionRegion(OpAsmPrinter &printer,
1061                                        ReductionDeclareOp op, Region &region) {
1062   if (region.empty())
1063     return;
1064   printer << "atomic ";
1065   printer.printRegion(region);
1066 }
1067 
1068 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) {
1069   if (op.initializerRegion().empty())
1070     return op.emitOpError() << "expects non-empty initializer region";
1071   Block &initializerEntryBlock = op.initializerRegion().front();
1072   if (initializerEntryBlock.getNumArguments() != 1 ||
1073       initializerEntryBlock.getArgument(0).getType() != op.type()) {
1074     return op.emitOpError() << "expects initializer region with one argument "
1075                                "of the reduction type";
1076   }
1077 
1078   for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) {
1079     if (yieldOp.results().size() != 1 ||
1080         yieldOp.results().getTypes()[0] != op.type())
1081       return op.emitOpError() << "expects initializer region to yield a value "
1082                                  "of the reduction type";
1083   }
1084 
1085   if (op.reductionRegion().empty())
1086     return op.emitOpError() << "expects non-empty reduction region";
1087   Block &reductionEntryBlock = op.reductionRegion().front();
1088   if (reductionEntryBlock.getNumArguments() != 2 ||
1089       reductionEntryBlock.getArgumentTypes()[0] !=
1090           reductionEntryBlock.getArgumentTypes()[1] ||
1091       reductionEntryBlock.getArgumentTypes()[0] != op.type())
1092     return op.emitOpError() << "expects reduction region with two arguments of "
1093                                "the reduction type";
1094   for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) {
1095     if (yieldOp.results().size() != 1 ||
1096         yieldOp.results().getTypes()[0] != op.type())
1097       return op.emitOpError() << "expects reduction region to yield a value "
1098                                  "of the reduction type";
1099   }
1100 
1101   if (op.atomicReductionRegion().empty())
1102     return success();
1103 
1104   Block &atomicReductionEntryBlock = op.atomicReductionRegion().front();
1105   if (atomicReductionEntryBlock.getNumArguments() != 2 ||
1106       atomicReductionEntryBlock.getArgumentTypes()[0] !=
1107           atomicReductionEntryBlock.getArgumentTypes()[1])
1108     return op.emitOpError() << "expects atomic reduction region with two "
1109                                "arguments of the same type";
1110   auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
1111                      .dyn_cast<PointerLikeType>();
1112   if (!ptrType || ptrType.getElementType() != op.type())
1113     return op.emitOpError() << "expects atomic reduction region arguments to "
1114                                "be accumulators containing the reduction type";
1115   return success();
1116 }
1117 
1118 static LogicalResult verifyReductionOp(ReductionOp op) {
1119   // TODO: generalize this to an op interface when there is more than one op
1120   // that supports reductions.
1121   auto container = op->getParentOfType<WsLoopOp>();
1122   for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
1123     if (container.reduction_vars()[i] == op.accumulator())
1124       return success();
1125 
1126   return op.emitOpError() << "the accumulator is not used by the parent";
1127 }
1128 
1129 //===----------------------------------------------------------------------===//
1130 // WsLoopOp
1131 //===----------------------------------------------------------------------===//
1132 
1133 void WsLoopOp::build(OpBuilder &builder, OperationState &state,
1134                      ValueRange lowerBound, ValueRange upperBound,
1135                      ValueRange step, ArrayRef<NamedAttribute> attributes) {
1136   build(builder, state, TypeRange(), lowerBound, upperBound, step,
1137         /*private_vars=*/ValueRange(),
1138         /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
1139         /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
1140         /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr,
1141         /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr,
1142         /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr,
1143         /*inclusive=*/nullptr, /*buildBody=*/false);
1144   state.addAttributes(attributes);
1145 }
1146 
1147 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes,
1148                      ValueRange operands, ArrayRef<NamedAttribute> attributes) {
1149   state.addOperands(operands);
1150   state.addAttributes(attributes);
1151   (void)state.addRegion();
1152   assert(resultTypes.empty() && "mismatched number of return types");
1153   state.addTypes(resultTypes);
1154 }
1155 
1156 void WsLoopOp::build(OpBuilder &builder, OperationState &result,
1157                      TypeRange typeRange, ValueRange lowerBounds,
1158                      ValueRange upperBounds, ValueRange steps,
1159                      ValueRange privateVars, ValueRange firstprivateVars,
1160                      ValueRange lastprivateVars, ValueRange linearVars,
1161                      ValueRange linearStepVars, ValueRange reductionVars,
1162                      StringAttr scheduleVal, Value scheduleChunkVar,
1163                      IntegerAttr collapseVal, UnitAttr nowait,
1164                      IntegerAttr orderedVal, StringAttr orderVal,
1165                      UnitAttr inclusive, bool buildBody) {
1166   result.addOperands(lowerBounds);
1167   result.addOperands(upperBounds);
1168   result.addOperands(steps);
1169   result.addOperands(privateVars);
1170   result.addOperands(firstprivateVars);
1171   result.addOperands(linearVars);
1172   result.addOperands(linearStepVars);
1173   if (scheduleChunkVar)
1174     result.addOperands(scheduleChunkVar);
1175 
1176   if (scheduleVal)
1177     result.addAttribute("schedule_val", scheduleVal);
1178   if (collapseVal)
1179     result.addAttribute("collapse_val", collapseVal);
1180   if (nowait)
1181     result.addAttribute("nowait", nowait);
1182   if (orderedVal)
1183     result.addAttribute("ordered_val", orderedVal);
1184   if (orderVal)
1185     result.addAttribute("order", orderVal);
1186   if (inclusive)
1187     result.addAttribute("inclusive", inclusive);
1188   result.addAttribute(
1189       WsLoopOp::getOperandSegmentSizeAttr(),
1190       builder.getI32VectorAttr(
1191           {static_cast<int32_t>(lowerBounds.size()),
1192            static_cast<int32_t>(upperBounds.size()),
1193            static_cast<int32_t>(steps.size()),
1194            static_cast<int32_t>(privateVars.size()),
1195            static_cast<int32_t>(firstprivateVars.size()),
1196            static_cast<int32_t>(lastprivateVars.size()),
1197            static_cast<int32_t>(linearVars.size()),
1198            static_cast<int32_t>(linearStepVars.size()),
1199            static_cast<int32_t>(reductionVars.size()),
1200            static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)}));
1201 
1202   Region *bodyRegion = result.addRegion();
1203   if (buildBody) {
1204     OpBuilder::InsertionGuard guard(builder);
1205     unsigned numIVs = steps.size();
1206     SmallVector<Type, 8> argTypes(numIVs, steps.getType().front());
1207     builder.createBlock(bodyRegion, {}, argTypes);
1208   }
1209 }
1210 
1211 static LogicalResult verifyWsLoopOp(WsLoopOp op) {
1212   return verifyReductionVarList(op, op.reductions(), op.reduction_vars());
1213 }
1214 
1215 //===----------------------------------------------------------------------===//
1216 // Verifier for critical construct (2.17.1)
1217 //===----------------------------------------------------------------------===//
1218 
1219 static LogicalResult verifyCriticalDeclareOp(CriticalDeclareOp op) {
1220   return verifySynchronizationHint(op, op.hint());
1221 }
1222 
1223 static LogicalResult verifyCriticalOp(CriticalOp op) {
1224 
1225   if (op.nameAttr()) {
1226     auto symbolRef = op.nameAttr().cast<SymbolRefAttr>();
1227     auto decl =
1228         SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef);
1229     if (!decl) {
1230       return op.emitOpError() << "expected symbol reference " << symbolRef
1231                               << " to point to a critical declaration";
1232     }
1233   }
1234 
1235   return success();
1236 }
1237 
1238 //===----------------------------------------------------------------------===//
1239 // Verifier for ordered construct
1240 //===----------------------------------------------------------------------===//
1241 
1242 static LogicalResult verifyOrderedOp(OrderedOp op) {
1243   auto container = op->getParentOfType<WsLoopOp>();
1244   if (!container || !container.ordered_valAttr() ||
1245       container.ordered_valAttr().getInt() == 0)
1246     return op.emitOpError() << "ordered depend directive must be closely "
1247                             << "nested inside a worksharing-loop with ordered "
1248                             << "clause with parameter present";
1249 
1250   if (container.ordered_valAttr().getInt() !=
1251       (int64_t)op.num_loops_val().getValue())
1252     return op.emitOpError() << "number of variables in depend clause does not "
1253                             << "match number of iteration variables in the "
1254                             << "doacross loop";
1255 
1256   return success();
1257 }
1258 
1259 static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) {
1260   // TODO: The code generation for ordered simd directive is not supported yet.
1261   if (op.simd())
1262     return failure();
1263 
1264   if (auto container = op->getParentOfType<WsLoopOp>()) {
1265     if (!container.ordered_valAttr() ||
1266         container.ordered_valAttr().getInt() != 0)
1267       return op.emitOpError() << "ordered region must be closely nested inside "
1268                               << "a worksharing-loop region with an ordered "
1269                               << "clause without parameter present";
1270   }
1271 
1272   return success();
1273 }
1274 
1275 //===----------------------------------------------------------------------===//
1276 // AtomicReadOp
1277 //===----------------------------------------------------------------------===//
1278 
1279 /// Parser for AtomicReadOp
1280 ///
1281 /// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type
1282 /// address ::= operand `:` type
1283 static ParseResult parseAtomicReadOp(OpAsmParser &parser,
1284                                      OperationState &result) {
1285   OpAsmParser::OperandType address;
1286   Type addressType;
1287   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1288   SmallVector<int> segments;
1289 
1290   if (parser.parseOperand(address) ||
1291       parseClauses(parser, result, clauses, segments) ||
1292       parser.parseColonType(addressType) ||
1293       parser.resolveOperand(address, addressType, result.operands))
1294     return failure();
1295 
1296   SmallVector<Type> resultType;
1297   if (parser.parseArrowTypeList(resultType))
1298     return failure();
1299   result.addTypes(resultType);
1300   return success();
1301 }
1302 
1303 /// Printer for AtomicReadOp
1304 static void printAtomicReadOp(OpAsmPrinter &p, AtomicReadOp op) {
1305   p << " " << op.address() << " ";
1306   if (op.memory_order())
1307     p << "memory_order(" << op.memory_order().getValue() << ") ";
1308   if (op.hintAttr())
1309     printSynchronizationHint(p << " ", op, op.hintAttr());
1310   p << ": " << op.address().getType() << " -> " << op.getType();
1311   return;
1312 }
1313 
1314 /// Verifier for AtomicReadOp
1315 static LogicalResult verifyAtomicReadOp(AtomicReadOp op) {
1316   if (op.memory_order()) {
1317     StringRef memOrder = op.memory_order().getValue();
1318     if (memOrder.equals("acq_rel") || memOrder.equals("release"))
1319       return op.emitError(
1320           "memory-order must not be acq_rel or release for atomic reads");
1321   }
1322   return verifySynchronizationHint(op, op.hint());
1323 }
1324 
1325 //===----------------------------------------------------------------------===//
1326 // AtomicWriteOp
1327 //===----------------------------------------------------------------------===//
1328 
1329 /// Parser for AtomicWriteOp
1330 ///
1331 /// operation ::= `omp.atomic.write` atomic-clause-list operands
1332 /// operands ::= address `,` value
1333 /// address ::= operand `:` type
1334 /// value ::= operand `:` type
1335 static ParseResult parseAtomicWriteOp(OpAsmParser &parser,
1336                                       OperationState &result) {
1337   OpAsmParser::OperandType address, value;
1338   Type addrType, valueType;
1339   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1340   SmallVector<int> segments;
1341 
1342   if (parser.parseOperand(address) || parser.parseComma() ||
1343       parser.parseOperand(value) ||
1344       parseClauses(parser, result, clauses, segments) ||
1345       parser.parseColonType(addrType) || parser.parseComma() ||
1346       parser.parseType(valueType) ||
1347       parser.resolveOperand(address, addrType, result.operands) ||
1348       parser.resolveOperand(value, valueType, result.operands))
1349     return failure();
1350   return success();
1351 }
1352 
1353 /// Printer for AtomicWriteOp
1354 static void printAtomicWriteOp(OpAsmPrinter &p, AtomicWriteOp op) {
1355   p << " " << op.address() << ", " << op.value() << " ";
1356   if (op.memory_order())
1357     p << "memory_order(" << op.memory_order() << ") ";
1358   if (op.hintAttr())
1359     printSynchronizationHint(p, op, op.hintAttr());
1360   p << ": " << op.address().getType() << ", " << op.value().getType();
1361   return;
1362 }
1363 
1364 /// Verifier for AtomicWriteOp
1365 static LogicalResult verifyAtomicWriteOp(AtomicWriteOp op) {
1366   if (op.memory_order()) {
1367     StringRef memoryOrder = op.memory_order().getValue();
1368     if (memoryOrder.equals("acq_rel") || memoryOrder.equals("acquire"))
1369       return op.emitError(
1370           "memory-order must not be acq_rel or acquire for atomic writes");
1371   }
1372   return verifySynchronizationHint(op, op.hint());
1373 }
1374 
1375 #define GET_OP_CLASSES
1376 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
1377