1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/OperationSupport.h"
19 
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/SmallString.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/ADT/StringSwitch.h"
25 #include <cstddef>
26 
27 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
28 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
29 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
30 
31 using namespace mlir;
32 using namespace mlir::omp;
33 
34 namespace {
35 /// Model for pointer-like types that already provide a `getElementType` method.
36 template <typename T>
37 struct PointerLikeModel
38     : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
39   Type getElementType(Type pointer) const {
40     return pointer.cast<T>().getElementType();
41   }
42 };
43 } // end namespace
44 
/// Registers every operation from the TableGen-generated op list with the
/// dialect. Also attaches the `PointerLikeType` interface (via
/// `PointerLikeModel`) to the LLVM dialect pointer type and the builtin memref
/// type.
void OpenMPDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();

  // Attaching the interface lets values of these types be treated as
  // `PointerLikeType`, e.g. reduction accumulators, which are cast to
  // `PointerLikeType` in verifyReductionVarList.
  LLVM::LLVMPointerType::attachInterface<
      PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
  MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
}
55 
56 //===----------------------------------------------------------------------===//
57 // ParallelOp
58 //===----------------------------------------------------------------------===//
59 
60 void ParallelOp::build(OpBuilder &builder, OperationState &state,
61                        ArrayRef<NamedAttribute> attributes) {
62   ParallelOp::build(
63       builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
64       /*default_val=*/nullptr, /*private_vars=*/ValueRange(),
65       /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(),
66       /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(),
67       /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr);
68   state.addAttributes(attributes);
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // Parser and printer for Operand and type list
73 //===----------------------------------------------------------------------===//
74 
75 /// Parse a list of operands with types.
76 ///
77 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
78 /// ssa-id-and-type-list ::= ssa-id-and-type |
79 ///                          ssa-id-and-type `,` ssa-id-and-type-list
80 /// ssa-id-and-type ::= ssa-id `:` type
81 static ParseResult
82 parseOperandAndTypeList(OpAsmParser &parser,
83                         SmallVectorImpl<OpAsmParser::OperandType> &operands,
84                         SmallVectorImpl<Type> &types) {
85   return parser.parseCommaSeparatedList(
86       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
87         OpAsmParser::OperandType operand;
88         Type type;
89         if (parser.parseOperand(operand) || parser.parseColonType(type))
90           return failure();
91         operands.push_back(operand);
92         types.push_back(type);
93         return success();
94       });
95 }
96 
97 /// Print an operand and type list with parentheses
98 static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) {
99   p << "(";
100   llvm::interleaveComma(
101       operands, p, [&](const Value &v) { p << v << " : " << v.getType(); });
102   p << ") ";
103 }
104 
105 /// Print data variables corresponding to a data-sharing clause `name`
106 static void printDataVars(OpAsmPrinter &p, OperandRange operands,
107                           StringRef name) {
108   if (operands.size()) {
109     p << name;
110     printOperandAndTypeList(p, operands);
111   }
112 }
113 
114 //===----------------------------------------------------------------------===//
115 // Parser and printer for Allocate Clause
116 //===----------------------------------------------------------------------===//
117 
118 /// Parse an allocate clause with allocators and a list of operands with types.
119 ///
120 /// allocate ::= `allocate` `(` allocate-operand-list `)`
121 /// allocate-operand-list :: = allocate-operand |
122 ///                            allocator-operand `,` allocate-operand-list
123 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
124 /// ssa-id-and-type ::= ssa-id `:` type
125 static ParseResult parseAllocateAndAllocator(
126     OpAsmParser &parser,
127     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
128     SmallVectorImpl<Type> &typesAllocate,
129     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
130     SmallVectorImpl<Type> &typesAllocator) {
131 
132   return parser.parseCommaSeparatedList(
133       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
134         OpAsmParser::OperandType operand;
135         Type type;
136         if (parser.parseOperand(operand) || parser.parseColonType(type))
137           return failure();
138         operandsAllocator.push_back(operand);
139         typesAllocator.push_back(type);
140         if (parser.parseArrow())
141           return failure();
142         if (parser.parseOperand(operand) || parser.parseColonType(type))
143           return failure();
144 
145         operandsAllocate.push_back(operand);
146         typesAllocate.push_back(type);
147         return success();
148       });
149 }
150 
151 /// Print allocate clause
152 static void printAllocateAndAllocator(OpAsmPrinter &p,
153                                       OperandRange varsAllocate,
154                                       OperandRange varsAllocator) {
155   if (varsAllocate.empty())
156     return;
157 
158   p << "allocate(";
159   for (unsigned i = 0; i < varsAllocate.size(); ++i) {
160     std::string separator = i == varsAllocate.size() - 1 ? ") " : ", ";
161     p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
162     p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
163   }
164 }
165 
166 static LogicalResult verifyParallelOp(ParallelOp op) {
167   if (op.allocate_vars().size() != op.allocators_vars().size())
168     return op.emitError(
169         "expected equal sizes for allocate and allocator variables");
170   return success();
171 }
172 
173 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
174   p << " ";
175   if (auto ifCond = op.if_expr_var())
176     p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
177 
178   if (auto threads = op.num_threads_var())
179     p << "num_threads(" << threads << " : " << threads.getType() << ") ";
180 
181   printDataVars(p, op.private_vars(), "private");
182   printDataVars(p, op.firstprivate_vars(), "firstprivate");
183   printDataVars(p, op.shared_vars(), "shared");
184   printDataVars(p, op.copyin_vars(), "copyin");
185   printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars());
186 
187   if (auto def = op.default_val())
188     p << "default(" << def->drop_front(3) << ") ";
189 
190   if (auto bind = op.proc_bind_val())
191     p << "proc_bind(" << bind << ") ";
192 
193   p.printRegion(op.getRegion());
194 }
195 
196 //===----------------------------------------------------------------------===//
197 // Parser and printer for Linear Clause
198 //===----------------------------------------------------------------------===//
199 
200 /// linear ::= `linear` `(` linear-list `)`
201 /// linear-list := linear-val | linear-val linear-list
202 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
203 static ParseResult
204 parseLinearClause(OpAsmParser &parser,
205                   SmallVectorImpl<OpAsmParser::OperandType> &vars,
206                   SmallVectorImpl<Type> &types,
207                   SmallVectorImpl<OpAsmParser::OperandType> &stepVars) {
208   if (parser.parseLParen())
209     return failure();
210 
211   do {
212     OpAsmParser::OperandType var;
213     Type type;
214     OpAsmParser::OperandType stepVar;
215     if (parser.parseOperand(var) || parser.parseEqual() ||
216         parser.parseOperand(stepVar) || parser.parseColonType(type))
217       return failure();
218 
219     vars.push_back(var);
220     types.push_back(type);
221     stepVars.push_back(stepVar);
222   } while (succeeded(parser.parseOptionalComma()));
223 
224   if (parser.parseRParen())
225     return failure();
226 
227   return success();
228 }
229 
230 /// Print Linear Clause
231 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars,
232                               OperandRange linearStepVars) {
233   size_t linearVarsSize = linearVars.size();
234   p << "(";
235   for (unsigned i = 0; i < linearVarsSize; ++i) {
236     std::string separator = i == linearVarsSize - 1 ? ") " : ", ";
237     p << linearVars[i];
238     if (linearStepVars.size() > i)
239       p << " = " << linearStepVars[i];
240     p << " : " << linearVars[i].getType() << separator;
241   }
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // Parser and printer for Schedule Clause
246 //===----------------------------------------------------------------------===//
247 
248 /// schedule ::= `schedule` `(` sched-list `)`
249 /// sched-list ::= sched-val | sched-val sched-list
250 /// sched-val ::= sched-with-chunk | sched-wo-chunk
251 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
252 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
253 /// sched-wo-chunk ::=  `auto` | `runtime`
254 static ParseResult
255 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
256                     SmallVectorImpl<SmallString<12>> &modifiers,
257                     Optional<OpAsmParser::OperandType> &chunkSize) {
258   if (parser.parseLParen())
259     return failure();
260 
261   StringRef keyword;
262   if (parser.parseKeyword(&keyword))
263     return failure();
264 
265   schedule = keyword;
266   if (keyword == "static" || keyword == "dynamic" || keyword == "guided") {
267     if (succeeded(parser.parseOptionalEqual())) {
268       chunkSize = OpAsmParser::OperandType{};
269       if (parser.parseOperand(*chunkSize))
270         return failure();
271     } else {
272       chunkSize = llvm::NoneType::None;
273     }
274   } else if (keyword == "auto" || keyword == "runtime") {
275     chunkSize = llvm::NoneType::None;
276   } else {
277     return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
278   }
279 
280   // If there is a comma, we have one or more modifiers..
281   if (succeeded(parser.parseOptionalComma())) {
282     StringRef mod;
283     if (parser.parseKeyword(&mod))
284       return failure();
285     modifiers.push_back(mod);
286   }
287 
288   if (parser.parseRParen())
289     return failure();
290 
291   return success();
292 }
293 
294 /// Print schedule clause
295 static void printScheduleClause(OpAsmPrinter &p, StringRef &sched,
296                                 llvm::Optional<StringRef> modifier,
297                                 Value scheduleChunkVar) {
298   std::string schedLower = sched.lower();
299   p << "(" << schedLower;
300   if (scheduleChunkVar)
301     p << " = " << scheduleChunkVar;
302   if (modifier && modifier.getValue() != "none")
303     p << ", " << modifier;
304   p << ") ";
305 }
306 
307 //===----------------------------------------------------------------------===//
308 // Parser, printer and verifier for ReductionVarList
309 //===----------------------------------------------------------------------===//
310 
311 /// reduction ::= `reduction` `(` reduction-entry-list `)`
312 /// reduction-entry-list ::= reduction-entry
313 ///                        | reduction-entry-list `,` reduction-entry
314 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
315 static ParseResult
316 parseReductionVarList(OpAsmParser &parser,
317                       SmallVectorImpl<SymbolRefAttr> &symbols,
318                       SmallVectorImpl<OpAsmParser::OperandType> &operands,
319                       SmallVectorImpl<Type> &types) {
320   if (failed(parser.parseLParen()))
321     return failure();
322 
323   do {
324     if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() ||
325         parser.parseOperand(operands.emplace_back()) ||
326         parser.parseColonType(types.emplace_back()))
327       return failure();
328   } while (succeeded(parser.parseOptionalComma()));
329   return parser.parseRParen();
330 }
331 
332 /// Print Reduction clause
333 static void printReductionVarList(OpAsmPrinter &p,
334                                   Optional<ArrayAttr> reductions,
335                                   OperandRange reduction_vars) {
336   for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
337     if (i != 0)
338       p << ", ";
339     p << (*reductions)[i] << " -> " << reduction_vars[i] << " : "
340       << reduction_vars[i].getType();
341   }
342   p << ") ";
343 }
344 
345 /// Verifies Reduction Clause
346 static LogicalResult verifyReductionVarList(Operation *op,
347                                             Optional<ArrayAttr> reductions,
348                                             OperandRange reduction_vars) {
349   if (reduction_vars.size() != 0) {
350     if (!reductions || reductions->size() != reduction_vars.size())
351       return op->emitOpError()
352              << "expected as many reduction symbol references "
353                 "as reduction variables";
354   } else {
355     if (reductions)
356       return op->emitOpError() << "unexpected reduction symbol references";
357     return success();
358   }
359 
360   DenseSet<Value> accumulators;
361   for (auto args : llvm::zip(reduction_vars, *reductions)) {
362     Value accum = std::get<0>(args);
363 
364     if (!accumulators.insert(accum).second)
365       return op->emitOpError() << "accumulator variable used more than once";
366 
367     Type varType = accum.getType().cast<PointerLikeType>();
368     auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
369     auto decl =
370         SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
371     if (!decl)
372       return op->emitOpError() << "expected symbol reference " << symbolRef
373                                << " to point to a reduction declaration";
374 
375     if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
376       return op->emitOpError()
377              << "expected accumulator (" << varType
378              << ") to be the same type as reduction declaration ("
379              << decl.getAccumulatorType() << ")";
380   }
381 
382   return success();
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // Parser, printer and verifier for Synchronization Hint (2.17.12)
387 //===----------------------------------------------------------------------===//
388 
389 /// Parses a Synchronization Hint clause. The value of hint is an integer
390 /// which is a combination of different hints from `omp_sync_hint_t`.
391 ///
392 /// hint-clause = `hint` `(` hint-value `)`
393 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
394                                             IntegerAttr &hintAttr,
395                                             bool parseKeyword = true) {
396   if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) {
397     hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
398     return success();
399   }
400 
401   if (failed(parser.parseLParen()))
402     return failure();
403   StringRef hintKeyword;
404   int64_t hint = 0;
405   do {
406     if (failed(parser.parseKeyword(&hintKeyword)))
407       return failure();
408     if (hintKeyword == "uncontended")
409       hint |= 1;
410     else if (hintKeyword == "contended")
411       hint |= 2;
412     else if (hintKeyword == "nonspeculative")
413       hint |= 4;
414     else if (hintKeyword == "speculative")
415       hint |= 8;
416     else
417       return parser.emitError(parser.getCurrentLocation())
418              << hintKeyword << " is not a valid hint";
419   } while (succeeded(parser.parseOptionalComma()));
420   if (failed(parser.parseRParen()))
421     return failure();
422   hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
423   return success();
424 }
425 
426 /// Prints a Synchronization Hint clause
427 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
428                                      IntegerAttr hintAttr) {
429   int64_t hint = hintAttr.getInt();
430 
431   if (hint == 0)
432     return;
433 
434   // Helper function to get n-th bit from the right end of `value`
435   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
436 
437   bool uncontended = bitn(hint, 0);
438   bool contended = bitn(hint, 1);
439   bool nonspeculative = bitn(hint, 2);
440   bool speculative = bitn(hint, 3);
441 
442   SmallVector<StringRef> hints;
443   if (uncontended)
444     hints.push_back("uncontended");
445   if (contended)
446     hints.push_back("contended");
447   if (nonspeculative)
448     hints.push_back("nonspeculative");
449   if (speculative)
450     hints.push_back("speculative");
451 
452   p << "hint(";
453   llvm::interleaveComma(hints, p);
454   p << ") ";
455 }
456 
457 /// Verifies a synchronization hint clause
458 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
459 
460   // Helper function to get n-th bit from the right end of `value`
461   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
462 
463   bool uncontended = bitn(hint, 0);
464   bool contended = bitn(hint, 1);
465   bool nonspeculative = bitn(hint, 2);
466   bool speculative = bitn(hint, 3);
467 
468   if (uncontended && contended)
469     return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
470                                 "omp_sync_hint_contended cannot be combined";
471   if (nonspeculative && speculative)
472     return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
473                                 "omp_sync_hint_speculative cannot be combined.";
474   return success();
475 }
476 
/// The clauses that `parseClauses` can recognise. Each operation passes the
/// subset of clauses it accepts, and any other clause keyword is rejected with
/// an error. Some clauses (default, proc_bind, nowait, collapse, order,
/// ordered) occupy no position in the operand segments; allocate and linear
/// each reserve two positions.
enum ClauseType {
  ifClause,
  numThreadsClause,
  privateClause,
  firstprivateClause,
  lastprivateClause,
  sharedClause,
  copyinClause,
  allocateClause,
  defaultClause,
  procBindClause,
  reductionClause,
  nowaitClause,
  linearClause,
  scheduleClause,
  collapseClause,
  orderClause,
  orderedClause,
  memoryOrderClause,
  hintClause,
  // Not a clause: the number of enumerators above, used to size per-clause
  // bookkeeping tables.
  COUNT
};
499 
500 //===----------------------------------------------------------------------===//
501 // Parser for Clause List
502 //===----------------------------------------------------------------------===//
503 
/// Parse a list of clauses. The clauses can appear in any order, but their
/// operand segment indices are in the same order that they are passed in the
/// `clauses` list. The resulting operand segment sizes are added to
/// `segments`.
///
/// clause-list ::= clause clause-list | empty
/// clause ::= if | num-threads | private | firstprivate | lastprivate |
///            shared | copyin | allocate | default | proc-bind | reduction |
///            nowait | linear | schedule | collapse | order | ordered |
///            memory-order | hint
/// if ::= `if` `(` ssa-id-and-type `)`
/// num-threads ::= `num_threads` `(` ssa-id-and-type `)`
/// private ::= `private` operand-and-type-list
/// firstprivate ::= `firstprivate` operand-and-type-list
/// lastprivate ::= `lastprivate` operand-and-type-list
/// shared ::= `shared` operand-and-type-list
/// copyin ::= `copyin` operand-and-type-list
/// allocate ::= `allocate` `(` allocate-operand-list `)`
/// default ::= `default` `(`
///             (`private` | `firstprivate` | `shared` | `none`) `)`
/// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
/// reduction ::= `reduction` `(` reduction-entry-list `)`
/// nowait ::= `nowait`
/// linear ::= `linear` `(` linear-list `)`
/// schedule ::= `schedule` `(` sched-list `)`
/// collapse ::= `collapse` `(` integer-literal `)`
/// order ::= `order` `(` `concurrent` `)`
/// ordered ::= `ordered` (`(` integer-literal `)`)?
/// memory-order ::= `memory_order` `(` bare-id `)`
/// hint ::= `hint` `(` hint-value (`,` hint-value)* `)`
///
/// Note that each clause can only appear once in the clause-list.
533 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
534                                 SmallVectorImpl<ClauseType> &clauses,
535                                 SmallVectorImpl<int> &segments) {
536 
537   // Check done[clause] to see if it has been parsed already
538   llvm::BitVector done(ClauseType::COUNT, false);
539 
540   // See pos[clause] to get position of clause in operand segments
541   SmallVector<int> pos(ClauseType::COUNT, -1);
542 
543   // Stores the last parsed clause keyword
544   StringRef clauseKeyword;
545   StringRef opName = result.name.getStringRef();
546 
547   // Containers for storing operands, types and attributes for various clauses
548   std::pair<OpAsmParser::OperandType, Type> ifCond;
549   std::pair<OpAsmParser::OperandType, Type> numThreads;
550 
551   SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates,
552       shareds, copyins;
553   SmallVector<Type> privateTypes, firstprivateTypes, lastprivateTypes,
554       sharedTypes, copyinTypes;
555 
556   SmallVector<OpAsmParser::OperandType> allocates, allocators;
557   SmallVector<Type> allocateTypes, allocatorTypes;
558 
559   SmallVector<SymbolRefAttr> reductionSymbols;
560   SmallVector<OpAsmParser::OperandType> reductionVars;
561   SmallVector<Type> reductionVarTypes;
562 
563   SmallVector<OpAsmParser::OperandType> linears;
564   SmallVector<Type> linearTypes;
565   SmallVector<OpAsmParser::OperandType> linearSteps;
566 
567   SmallString<8> schedule;
568   SmallVector<SmallString<12>> modifiers;
569   Optional<OpAsmParser::OperandType> scheduleChunkSize;
570 
571   // Compute the position of clauses in operand segments
572   int currPos = 0;
573   for (ClauseType clause : clauses) {
574 
575     // Skip the following clauses - they do not take any position in operand
576     // segments
577     if (clause == defaultClause || clause == procBindClause ||
578         clause == nowaitClause || clause == collapseClause ||
579         clause == orderClause || clause == orderedClause)
580       continue;
581 
582     pos[clause] = currPos++;
583 
584     // For the following clauses, two positions are reserved in the operand
585     // segments
586     if (clause == allocateClause || clause == linearClause)
587       currPos++;
588   }
589 
590   SmallVector<int> clauseSegments(currPos);
591 
592   // Helper function to check if a clause is allowed/repeated or not
593   auto checkAllowed = [&](ClauseType clause,
594                           bool allowRepeat = false) -> ParseResult {
595     if (!llvm::is_contained(clauses, clause))
596       return parser.emitError(parser.getCurrentLocation())
597              << clauseKeyword << " is not a valid clause for the " << opName
598              << " operation";
599     if (done[clause] && !allowRepeat)
600       return parser.emitError(parser.getCurrentLocation())
601              << "at most one " << clauseKeyword << " clause can appear on the "
602              << opName << " operation";
603     done[clause] = true;
604     return success();
605   };
606 
607   while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) {
608     if (clauseKeyword == "if") {
609       if (checkAllowed(ifClause) || parser.parseLParen() ||
610           parser.parseOperand(ifCond.first) ||
611           parser.parseColonType(ifCond.second) || parser.parseRParen())
612         return failure();
613       clauseSegments[pos[ifClause]] = 1;
614     } else if (clauseKeyword == "num_threads") {
615       if (checkAllowed(numThreadsClause) || parser.parseLParen() ||
616           parser.parseOperand(numThreads.first) ||
617           parser.parseColonType(numThreads.second) || parser.parseRParen())
618         return failure();
619       clauseSegments[pos[numThreadsClause]] = 1;
620     } else if (clauseKeyword == "private") {
621       if (checkAllowed(privateClause) ||
622           parseOperandAndTypeList(parser, privates, privateTypes))
623         return failure();
624       clauseSegments[pos[privateClause]] = privates.size();
625     } else if (clauseKeyword == "firstprivate") {
626       if (checkAllowed(firstprivateClause) ||
627           parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
628         return failure();
629       clauseSegments[pos[firstprivateClause]] = firstprivates.size();
630     } else if (clauseKeyword == "lastprivate") {
631       if (checkAllowed(lastprivateClause) ||
632           parseOperandAndTypeList(parser, lastprivates, lastprivateTypes))
633         return failure();
634       clauseSegments[pos[lastprivateClause]] = lastprivates.size();
635     } else if (clauseKeyword == "shared") {
636       if (checkAllowed(sharedClause) ||
637           parseOperandAndTypeList(parser, shareds, sharedTypes))
638         return failure();
639       clauseSegments[pos[sharedClause]] = shareds.size();
640     } else if (clauseKeyword == "copyin") {
641       if (checkAllowed(copyinClause) ||
642           parseOperandAndTypeList(parser, copyins, copyinTypes))
643         return failure();
644       clauseSegments[pos[copyinClause]] = copyins.size();
645     } else if (clauseKeyword == "allocate") {
646       if (checkAllowed(allocateClause) ||
647           parseAllocateAndAllocator(parser, allocates, allocateTypes,
648                                     allocators, allocatorTypes))
649         return failure();
650       clauseSegments[pos[allocateClause]] = allocates.size();
651       clauseSegments[pos[allocateClause] + 1] = allocators.size();
652     } else if (clauseKeyword == "default") {
653       StringRef defval;
654       if (checkAllowed(defaultClause) || parser.parseLParen() ||
655           parser.parseKeyword(&defval) || parser.parseRParen())
656         return failure();
657       // The def prefix is required for the attribute as "private" is a keyword
658       // in C++.
659       auto attr = parser.getBuilder().getStringAttr("def" + defval);
660       result.addAttribute("default_val", attr);
661     } else if (clauseKeyword == "proc_bind") {
662       StringRef bind;
663       if (checkAllowed(procBindClause) || parser.parseLParen() ||
664           parser.parseKeyword(&bind) || parser.parseRParen())
665         return failure();
666       auto attr = parser.getBuilder().getStringAttr(bind);
667       result.addAttribute("proc_bind_val", attr);
668     } else if (clauseKeyword == "reduction") {
669       if (checkAllowed(reductionClause) ||
670           parseReductionVarList(parser, reductionSymbols, reductionVars,
671                                 reductionVarTypes))
672         return failure();
673       clauseSegments[pos[reductionClause]] = reductionVars.size();
674     } else if (clauseKeyword == "nowait") {
675       if (checkAllowed(nowaitClause))
676         return failure();
677       auto attr = UnitAttr::get(parser.getBuilder().getContext());
678       result.addAttribute("nowait", attr);
679     } else if (clauseKeyword == "linear") {
680       if (checkAllowed(linearClause) ||
681           parseLinearClause(parser, linears, linearTypes, linearSteps))
682         return failure();
683       clauseSegments[pos[linearClause]] = linears.size();
684       clauseSegments[pos[linearClause] + 1] = linearSteps.size();
685     } else if (clauseKeyword == "schedule") {
686       if (checkAllowed(scheduleClause) ||
687           parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize))
688         return failure();
689       if (scheduleChunkSize) {
690         clauseSegments[pos[scheduleClause]] = 1;
691       }
692     } else if (clauseKeyword == "collapse") {
693       auto type = parser.getBuilder().getI64Type();
694       mlir::IntegerAttr attr;
695       if (checkAllowed(collapseClause) || parser.parseLParen() ||
696           parser.parseAttribute(attr, type) || parser.parseRParen())
697         return failure();
698       result.addAttribute("collapse_val", attr);
699     } else if (clauseKeyword == "ordered") {
700       mlir::IntegerAttr attr;
701       if (checkAllowed(orderedClause))
702         return failure();
703       if (succeeded(parser.parseOptionalLParen())) {
704         auto type = parser.getBuilder().getI64Type();
705         if (parser.parseAttribute(attr, type) || parser.parseRParen())
706           return failure();
707       } else {
708         // Use 0 to represent no ordered parameter was specified
709         attr = parser.getBuilder().getI64IntegerAttr(0);
710       }
711       result.addAttribute("ordered_val", attr);
712     } else if (clauseKeyword == "order") {
713       StringRef order;
714       if (checkAllowed(orderClause) || parser.parseLParen() ||
715           parser.parseKeyword(&order) || parser.parseRParen())
716         return failure();
717       auto attr = parser.getBuilder().getStringAttr(order);
718       result.addAttribute("order_val", attr);
719     } else if (clauseKeyword == "memory_order") {
720       StringRef memoryOrder;
721       if (checkAllowed(memoryOrderClause) || parser.parseLParen() ||
722           parser.parseKeyword(&memoryOrder) || parser.parseRParen())
723         return failure();
724       result.addAttribute("memory_order",
725                           parser.getBuilder().getStringAttr(memoryOrder));
726     } else if (clauseKeyword == "hint") {
727       IntegerAttr hint;
728       if (checkAllowed(hintClause) ||
729           parseSynchronizationHint(parser, hint, false))
730         return failure();
731       result.addAttribute("hint", hint);
732     } else {
733       return parser.emitError(parser.getNameLoc())
734              << clauseKeyword << " is not a valid clause";
735     }
736   }
737 
738   // Add if parameter.
739   if (done[ifClause] && clauseSegments[pos[ifClause]] &&
740       failed(
741           parser.resolveOperand(ifCond.first, ifCond.second, result.operands)))
742     return failure();
743 
744   // Add num_threads parameter.
745   if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] &&
746       failed(parser.resolveOperand(numThreads.first, numThreads.second,
747                                    result.operands)))
748     return failure();
749 
750   // Add private parameters.
751   if (done[privateClause] && clauseSegments[pos[privateClause]] &&
752       failed(parser.resolveOperands(privates, privateTypes,
753                                     privates[0].location, result.operands)))
754     return failure();
755 
756   // Add firstprivate parameters.
757   if (done[firstprivateClause] && clauseSegments[pos[firstprivateClause]] &&
758       failed(parser.resolveOperands(firstprivates, firstprivateTypes,
759                                     firstprivates[0].location,
760                                     result.operands)))
761     return failure();
762 
763   // Add lastprivate parameters.
764   if (done[lastprivateClause] && clauseSegments[pos[lastprivateClause]] &&
765       failed(parser.resolveOperands(lastprivates, lastprivateTypes,
766                                     lastprivates[0].location, result.operands)))
767     return failure();
768 
769   // Add shared parameters.
770   if (done[sharedClause] && clauseSegments[pos[sharedClause]] &&
771       failed(parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
772                                     result.operands)))
773     return failure();
774 
775   // Add copyin parameters.
776   if (done[copyinClause] && clauseSegments[pos[copyinClause]] &&
777       failed(parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
778                                     result.operands)))
779     return failure();
780 
781   // Add allocate parameters.
782   if (done[allocateClause] && clauseSegments[pos[allocateClause]] &&
783       failed(parser.resolveOperands(allocates, allocateTypes,
784                                     allocates[0].location, result.operands)))
785     return failure();
786 
787   // Add allocator parameters.
788   if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] &&
789       failed(parser.resolveOperands(allocators, allocatorTypes,
790                                     allocators[0].location, result.operands)))
791     return failure();
792 
793   // Add reduction parameters and symbols
794   if (done[reductionClause] && clauseSegments[pos[reductionClause]]) {
795     if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
796                                       parser.getNameLoc(), result.operands)))
797       return failure();
798 
799     SmallVector<Attribute> reductions(reductionSymbols.begin(),
800                                       reductionSymbols.end());
801     result.addAttribute("reductions",
802                         parser.getBuilder().getArrayAttr(reductions));
803   }
804 
805   // Add linear parameters
806   if (done[linearClause] && clauseSegments[pos[linearClause]]) {
807     auto linearStepType = parser.getBuilder().getI32Type();
808     SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
809     if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location,
810                                       result.operands)) ||
811         failed(parser.resolveOperands(linearSteps, linearStepTypes,
812                                       linearSteps[0].location,
813                                       result.operands)))
814       return failure();
815   }
816 
817   // Add schedule parameters
818   if (done[scheduleClause] && !schedule.empty()) {
819     schedule[0] = llvm::toUpper(schedule[0]);
820     auto attr = parser.getBuilder().getStringAttr(schedule);
821     result.addAttribute("schedule_val", attr);
822     if (modifiers.size() > 0) {
823       auto mod = parser.getBuilder().getStringAttr(modifiers[0]);
824       result.addAttribute("schedule_modifier", mod);
825     }
826     if (scheduleChunkSize) {
827       auto chunkSizeType = parser.getBuilder().getI32Type();
828       parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands);
829     }
830   }
831 
832   segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end());
833 
834   return success();
835 }
836 
837 /// Parses a parallel operation.
838 ///
839 /// operation ::= `omp.parallel` clause-list
840 /// clause-list ::= clause | clause clause-list
841 /// clause ::= if | num-threads | private | firstprivate | shared | copyin |
842 ///            allocate | default | proc-bind
843 ///
844 static ParseResult parseParallelOp(OpAsmParser &parser,
845                                    OperationState &result) {
846   SmallVector<ClauseType> clauses = {
847       ifClause,           numThreadsClause, privateClause,
848       firstprivateClause, sharedClause,     copyinClause,
849       allocateClause,     defaultClause,    procBindClause};
850 
851   SmallVector<int> segments;
852 
853   if (failed(parseClauses(parser, result, clauses, segments)))
854     return failure();
855 
856   result.addAttribute("operand_segment_sizes",
857                       parser.getBuilder().getI32VectorAttr(segments));
858 
859   Region *body = result.addRegion();
860   SmallVector<OpAsmParser::OperandType> regionArgs;
861   SmallVector<Type> regionArgTypes;
862   if (parser.parseRegion(*body, regionArgs, regionArgTypes))
863     return failure();
864   return success();
865 }
866 
867 /// Parses an OpenMP Workshare Loop operation
868 ///
869 /// wsloop ::= `omp.wsloop` loop-control clause-list
870 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
871 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
872 /// steps := `step` `(`ssa-id-list`)`
873 /// clause-list ::= clause clause-list | empty
874 /// clause ::= private | firstprivate | lastprivate | linear | schedule |
875 //             collapse | nowait | ordered | order | reduction
876 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
877 
878   // Parse an opening `(` followed by induction variables followed by `)`
879   SmallVector<OpAsmParser::OperandType> ivs;
880   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
881                                      OpAsmParser::Delimiter::Paren))
882     return failure();
883 
884   int numIVs = static_cast<int>(ivs.size());
885   Type loopVarType;
886   if (parser.parseColonType(loopVarType))
887     return failure();
888 
889   // Parse loop bounds.
890   SmallVector<OpAsmParser::OperandType> lower;
891   if (parser.parseEqual() ||
892       parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
893       parser.resolveOperands(lower, loopVarType, result.operands))
894     return failure();
895 
896   SmallVector<OpAsmParser::OperandType> upper;
897   if (parser.parseKeyword("to") ||
898       parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
899       parser.resolveOperands(upper, loopVarType, result.operands))
900     return failure();
901 
902   if (succeeded(parser.parseOptionalKeyword("inclusive"))) {
903     auto attr = UnitAttr::get(parser.getBuilder().getContext());
904     result.addAttribute("inclusive", attr);
905   }
906 
907   // Parse step values.
908   SmallVector<OpAsmParser::OperandType> steps;
909   if (parser.parseKeyword("step") ||
910       parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
911       parser.resolveOperands(steps, loopVarType, result.operands))
912     return failure();
913 
914   SmallVector<ClauseType> clauses = {
915       privateClause,   firstprivateClause, lastprivateClause, linearClause,
916       reductionClause, collapseClause,     orderClause,       orderedClause,
917       nowaitClause,    scheduleClause};
918   SmallVector<int> segments{numIVs, numIVs, numIVs};
919   if (failed(parseClauses(parser, result, clauses, segments)))
920     return failure();
921 
922   result.addAttribute("operand_segment_sizes",
923                       parser.getBuilder().getI32VectorAttr(segments));
924 
925   // Now parse the body.
926   Region *body = result.addRegion();
927   SmallVector<Type> ivTypes(numIVs, loopVarType);
928   SmallVector<OpAsmParser::OperandType> blockArgs(ivs);
929   if (parser.parseRegion(*body, blockArgs, ivTypes))
930     return failure();
931   return success();
932 }
933 
934 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
935   auto args = op.getRegion().front().getArguments();
936   p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound()
937     << ") to (" << op.upperBound() << ") ";
938   if (op.inclusive()) {
939     p << "inclusive ";
940   }
941   p << "step (" << op.step() << ") ";
942 
943   printDataVars(p, op.private_vars(), "private");
944   printDataVars(p, op.firstprivate_vars(), "firstprivate");
945   printDataVars(p, op.lastprivate_vars(), "lastprivate");
946 
947   if (op.linear_vars().size()) {
948     p << "linear";
949     printLinearClause(p, op.linear_vars(), op.linear_step_vars());
950   }
951 
952   if (auto sched = op.schedule_val()) {
953     p << "schedule";
954     printScheduleClause(p, sched.getValue(), op.schedule_modifier(),
955                         op.schedule_chunk_var());
956   }
957 
958   if (auto collapse = op.collapse_val())
959     p << "collapse(" << collapse << ") ";
960 
961   if (op.nowait())
962     p << "nowait ";
963 
964   if (auto ordered = op.ordered_val())
965     p << "ordered(" << ordered << ") ";
966 
967   if (auto order = op.order_val())
968     p << "order(" << order << ") ";
969 
970   if (!op.reduction_vars().empty()) {
971     p << "reduction(";
972     printReductionVarList(p, op.reductions(), op.reduction_vars());
973   }
974 
975   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
976 }
977 
978 //===----------------------------------------------------------------------===//
979 // ReductionOp
980 //===----------------------------------------------------------------------===//
981 
982 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
983                                               Region &region) {
984   if (parser.parseOptionalKeyword("atomic"))
985     return success();
986   return parser.parseRegion(region);
987 }
988 
989 static void printAtomicReductionRegion(OpAsmPrinter &printer,
990                                        ReductionDeclareOp op, Region &region) {
991   if (region.empty())
992     return;
993   printer << "atomic ";
994   printer.printRegion(region);
995 }
996 
997 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) {
998   if (op.initializerRegion().empty())
999     return op.emitOpError() << "expects non-empty initializer region";
1000   Block &initializerEntryBlock = op.initializerRegion().front();
1001   if (initializerEntryBlock.getNumArguments() != 1 ||
1002       initializerEntryBlock.getArgument(0).getType() != op.type()) {
1003     return op.emitOpError() << "expects initializer region with one argument "
1004                                "of the reduction type";
1005   }
1006 
1007   for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) {
1008     if (yieldOp.results().size() != 1 ||
1009         yieldOp.results().getTypes()[0] != op.type())
1010       return op.emitOpError() << "expects initializer region to yield a value "
1011                                  "of the reduction type";
1012   }
1013 
1014   if (op.reductionRegion().empty())
1015     return op.emitOpError() << "expects non-empty reduction region";
1016   Block &reductionEntryBlock = op.reductionRegion().front();
1017   if (reductionEntryBlock.getNumArguments() != 2 ||
1018       reductionEntryBlock.getArgumentTypes()[0] !=
1019           reductionEntryBlock.getArgumentTypes()[1] ||
1020       reductionEntryBlock.getArgumentTypes()[0] != op.type())
1021     return op.emitOpError() << "expects reduction region with two arguments of "
1022                                "the reduction type";
1023   for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) {
1024     if (yieldOp.results().size() != 1 ||
1025         yieldOp.results().getTypes()[0] != op.type())
1026       return op.emitOpError() << "expects reduction region to yield a value "
1027                                  "of the reduction type";
1028   }
1029 
1030   if (op.atomicReductionRegion().empty())
1031     return success();
1032 
1033   Block &atomicReductionEntryBlock = op.atomicReductionRegion().front();
1034   if (atomicReductionEntryBlock.getNumArguments() != 2 ||
1035       atomicReductionEntryBlock.getArgumentTypes()[0] !=
1036           atomicReductionEntryBlock.getArgumentTypes()[1])
1037     return op.emitOpError() << "expects atomic reduction region with two "
1038                                "arguments of the same type";
1039   auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
1040                      .dyn_cast<PointerLikeType>();
1041   if (!ptrType || ptrType.getElementType() != op.type())
1042     return op.emitOpError() << "expects atomic reduction region arguments to "
1043                                "be accumulators containing the reduction type";
1044   return success();
1045 }
1046 
1047 static LogicalResult verifyReductionOp(ReductionOp op) {
1048   // TODO: generalize this to an op interface when there is more than one op
1049   // that supports reductions.
1050   auto container = op->getParentOfType<WsLoopOp>();
1051   for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
1052     if (container.reduction_vars()[i] == op.accumulator())
1053       return success();
1054 
1055   return op.emitOpError() << "the accumulator is not used by the parent";
1056 }
1057 
1058 //===----------------------------------------------------------------------===//
1059 // WsLoopOp
1060 //===----------------------------------------------------------------------===//
1061 
1062 void WsLoopOp::build(OpBuilder &builder, OperationState &state,
1063                      ValueRange lowerBound, ValueRange upperBound,
1064                      ValueRange step, ArrayRef<NamedAttribute> attributes) {
1065   build(builder, state, TypeRange(), lowerBound, upperBound, step,
1066         /*private_vars=*/ValueRange(),
1067         /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
1068         /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
1069         /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr,
1070         /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr,
1071         /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr,
1072         /*inclusive=*/nullptr, /*buildBody=*/false);
1073   state.addAttributes(attributes);
1074 }
1075 
1076 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes,
1077                      ValueRange operands, ArrayRef<NamedAttribute> attributes) {
1078   state.addOperands(operands);
1079   state.addAttributes(attributes);
1080   (void)state.addRegion();
1081   assert(resultTypes.empty() && "mismatched number of return types");
1082   state.addTypes(resultTypes);
1083 }
1084 
1085 void WsLoopOp::build(OpBuilder &builder, OperationState &result,
1086                      TypeRange typeRange, ValueRange lowerBounds,
1087                      ValueRange upperBounds, ValueRange steps,
1088                      ValueRange privateVars, ValueRange firstprivateVars,
1089                      ValueRange lastprivateVars, ValueRange linearVars,
1090                      ValueRange linearStepVars, ValueRange reductionVars,
1091                      StringAttr scheduleVal, Value scheduleChunkVar,
1092                      IntegerAttr collapseVal, UnitAttr nowait,
1093                      IntegerAttr orderedVal, StringAttr orderVal,
1094                      UnitAttr inclusive, bool buildBody) {
1095   result.addOperands(lowerBounds);
1096   result.addOperands(upperBounds);
1097   result.addOperands(steps);
1098   result.addOperands(privateVars);
1099   result.addOperands(firstprivateVars);
1100   result.addOperands(linearVars);
1101   result.addOperands(linearStepVars);
1102   if (scheduleChunkVar)
1103     result.addOperands(scheduleChunkVar);
1104 
1105   if (scheduleVal)
1106     result.addAttribute("schedule_val", scheduleVal);
1107   if (collapseVal)
1108     result.addAttribute("collapse_val", collapseVal);
1109   if (nowait)
1110     result.addAttribute("nowait", nowait);
1111   if (orderedVal)
1112     result.addAttribute("ordered_val", orderedVal);
1113   if (orderVal)
1114     result.addAttribute("order", orderVal);
1115   if (inclusive)
1116     result.addAttribute("inclusive", inclusive);
1117   result.addAttribute(
1118       WsLoopOp::getOperandSegmentSizeAttr(),
1119       builder.getI32VectorAttr(
1120           {static_cast<int32_t>(lowerBounds.size()),
1121            static_cast<int32_t>(upperBounds.size()),
1122            static_cast<int32_t>(steps.size()),
1123            static_cast<int32_t>(privateVars.size()),
1124            static_cast<int32_t>(firstprivateVars.size()),
1125            static_cast<int32_t>(lastprivateVars.size()),
1126            static_cast<int32_t>(linearVars.size()),
1127            static_cast<int32_t>(linearStepVars.size()),
1128            static_cast<int32_t>(reductionVars.size()),
1129            static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)}));
1130 
1131   Region *bodyRegion = result.addRegion();
1132   if (buildBody) {
1133     OpBuilder::InsertionGuard guard(builder);
1134     unsigned numIVs = steps.size();
1135     SmallVector<Type, 8> argTypes(numIVs, steps.getType().front());
1136     builder.createBlock(bodyRegion, {}, argTypes);
1137   }
1138 }
1139 
1140 static LogicalResult verifyWsLoopOp(WsLoopOp op) {
1141   return verifyReductionVarList(op, op.reductions(), op.reduction_vars());
1142 }
1143 
1144 //===----------------------------------------------------------------------===//
1145 // Verifier for critical construct (2.17.1)
1146 //===----------------------------------------------------------------------===//
1147 
1148 static LogicalResult verifyCriticalDeclareOp(CriticalDeclareOp op) {
1149   return verifySynchronizationHint(op, op.hint());
1150 }
1151 
1152 static LogicalResult verifyCriticalOp(CriticalOp op) {
1153 
1154   if (op.nameAttr()) {
1155     auto symbolRef = op.nameAttr().cast<SymbolRefAttr>();
1156     auto decl =
1157         SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef);
1158     if (!decl) {
1159       return op.emitOpError() << "expected symbol reference " << symbolRef
1160                               << " to point to a critical declaration";
1161     }
1162   }
1163 
1164   return success();
1165 }
1166 
1167 //===----------------------------------------------------------------------===//
1168 // Verifier for ordered construct
1169 //===----------------------------------------------------------------------===//
1170 
1171 static LogicalResult verifyOrderedOp(OrderedOp op) {
1172   auto container = op->getParentOfType<WsLoopOp>();
1173   if (!container || !container.ordered_valAttr() ||
1174       container.ordered_valAttr().getInt() == 0)
1175     return op.emitOpError() << "ordered depend directive must be closely "
1176                             << "nested inside a worksharing-loop with ordered "
1177                             << "clause with parameter present";
1178 
1179   if (container.ordered_valAttr().getInt() !=
1180       (int64_t)op.num_loops_val().getValue())
1181     return op.emitOpError() << "number of variables in depend clause does not "
1182                             << "match number of iteration variables in the "
1183                             << "doacross loop";
1184 
1185   return success();
1186 }
1187 
1188 static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) {
1189   // TODO: The code generation for ordered simd directive is not supported yet.
1190   if (op.simd())
1191     return failure();
1192 
1193   if (auto container = op->getParentOfType<WsLoopOp>()) {
1194     if (!container.ordered_valAttr() ||
1195         container.ordered_valAttr().getInt() != 0)
1196       return op.emitOpError() << "ordered region must be closely nested inside "
1197                               << "a worksharing-loop region with an ordered "
1198                               << "clause without parameter present";
1199   }
1200 
1201   return success();
1202 }
1203 
1204 //===----------------------------------------------------------------------===//
1205 // AtomicReadOp
1206 //===----------------------------------------------------------------------===//
1207 
1208 /// Parser for AtomicReadOp
1209 ///
1210 /// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type
1211 /// address ::= operand `:` type
1212 static ParseResult parseAtomicReadOp(OpAsmParser &parser,
1213                                      OperationState &result) {
1214   OpAsmParser::OperandType address;
1215   Type addressType;
1216   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1217   SmallVector<int> segments;
1218 
1219   if (parser.parseOperand(address) ||
1220       parseClauses(parser, result, clauses, segments) ||
1221       parser.parseColonType(addressType) ||
1222       parser.resolveOperand(address, addressType, result.operands))
1223     return failure();
1224 
1225   SmallVector<Type> resultType;
1226   if (parser.parseArrowTypeList(resultType))
1227     return failure();
1228   result.addTypes(resultType);
1229   return success();
1230 }
1231 
1232 /// Printer for AtomicReadOp
1233 static void printAtomicReadOp(OpAsmPrinter &p, AtomicReadOp op) {
1234   p << " " << op.address() << " ";
1235   if (op.memory_order())
1236     p << "memory_order(" << op.memory_order().getValue() << ") ";
1237   if (op.hintAttr())
1238     printSynchronizationHint(p << " ", op, op.hintAttr());
1239   p << ": " << op.address().getType() << " -> " << op.getType();
1240   return;
1241 }
1242 
1243 /// Verifier for AtomicReadOp
1244 static LogicalResult verifyAtomicReadOp(AtomicReadOp op) {
1245   if (op.memory_order()) {
1246     StringRef memOrder = op.memory_order().getValue();
1247     if (memOrder.equals("acq_rel") || memOrder.equals("release"))
1248       return op.emitError(
1249           "memory-order must not be acq_rel or release for atomic reads");
1250   }
1251   return verifySynchronizationHint(op, op.hint());
1252 }
1253 
1254 //===----------------------------------------------------------------------===//
1255 // AtomicWriteOp
1256 //===----------------------------------------------------------------------===//
1257 
1258 /// Parser for AtomicWriteOp
1259 ///
1260 /// operation ::= `omp.atomic.write` atomic-clause-list operands
1261 /// operands ::= address `,` value
1262 /// address ::= operand `:` type
1263 /// value ::= operand `:` type
1264 static ParseResult parseAtomicWriteOp(OpAsmParser &parser,
1265                                       OperationState &result) {
1266   OpAsmParser::OperandType address, value;
1267   Type addrType, valueType;
1268   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1269   SmallVector<int> segments;
1270 
1271   if (parser.parseOperand(address) || parser.parseComma() ||
1272       parser.parseOperand(value) ||
1273       parseClauses(parser, result, clauses, segments) ||
1274       parser.parseColonType(addrType) || parser.parseComma() ||
1275       parser.parseType(valueType) ||
1276       parser.resolveOperand(address, addrType, result.operands) ||
1277       parser.resolveOperand(value, valueType, result.operands))
1278     return failure();
1279   return success();
1280 }
1281 
1282 /// Printer for AtomicWriteOp
1283 static void printAtomicWriteOp(OpAsmPrinter &p, AtomicWriteOp op) {
1284   p << " " << op.address() << ", " << op.value() << " ";
1285   if (op.memory_order())
1286     p << "memory_order(" << op.memory_order() << ") ";
1287   if (op.hintAttr())
1288     printSynchronizationHint(p, op, op.hintAttr());
1289   p << ": " << op.address().getType() << ", " << op.value().getType();
1290   return;
1291 }
1292 
1293 /// Verifier for AtomicWriteOp
1294 static LogicalResult verifyAtomicWriteOp(AtomicWriteOp op) {
1295   if (op.memory_order()) {
1296     StringRef memoryOrder = op.memory_order().getValue();
1297     if (memoryOrder.equals("acq_rel") || memoryOrder.equals("acquire"))
1298       return op.emitError(
1299           "memory-order must not be acq_rel or acquire for atomic writes");
1300   }
1301   return verifySynchronizationHint(op, op.hint());
1302 }
1303 
1304 #define GET_OP_CLASSES
1305 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
1306