1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/OperationSupport.h"
19 
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/SmallString.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/ADT/StringSwitch.h"
25 #include <cstddef>
26 
27 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
28 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
29 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
30 
31 using namespace mlir;
32 using namespace mlir::omp;
33 
34 namespace {
35 /// Model for pointer-like types that already provide a `getElementType` method.
36 template <typename T>
37 struct PointerLikeModel
38     : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
39   Type getElementType(Type pointer) const {
40     return pointer.cast<T>().getElementType();
41   }
42 };
43 } // end namespace
44 
/// Registers the operations of the dialect and attaches the `PointerLikeType`
/// external models to the LLVM pointer type and the memref type.
void OpenMPDialect::initialize() {
  // The list of operations is expanded from the generated .inc file.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();

  // Both types already provide `getElementType`, which is all that
  // PointerLikeModel needs to implement the interface on their behalf.
  LLVM::LLVMPointerType::attachInterface<
      PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
  MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
}
55 
56 //===----------------------------------------------------------------------===//
57 // ParallelOp
58 //===----------------------------------------------------------------------===//
59 
/// Builds a ParallelOp without any clause: every optional operand and
/// attribute is passed as null and every variadic operand list as an empty
/// range. `attributes` is then added to `state` as-is.
void ParallelOp::build(OpBuilder &builder, OperationState &state,
                       ArrayRef<NamedAttribute> attributes) {
  ParallelOp::build(
      builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
      /*default_val=*/nullptr, /*private_vars=*/ValueRange(),
      /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(),
      /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(),
      /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr);
  state.addAttributes(attributes);
}
70 
71 //===----------------------------------------------------------------------===//
72 // Parser and printer for Operand and type list
73 //===----------------------------------------------------------------------===//
74 
75 /// Parse a list of operands with types.
76 ///
77 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
78 /// ssa-id-and-type-list ::= ssa-id-and-type |
79 ///                          ssa-id-and-type `,` ssa-id-and-type-list
80 /// ssa-id-and-type ::= ssa-id `:` type
81 static ParseResult
82 parseOperandAndTypeList(OpAsmParser &parser,
83                         SmallVectorImpl<OpAsmParser::OperandType> &operands,
84                         SmallVectorImpl<Type> &types) {
85   return parser.parseCommaSeparatedList(
86       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
87         OpAsmParser::OperandType operand;
88         Type type;
89         if (parser.parseOperand(operand) || parser.parseColonType(type))
90           return failure();
91         operands.push_back(operand);
92         types.push_back(type);
93         return success();
94       });
95 }
96 
97 /// Print an operand and type list with parentheses
98 static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) {
99   p << "(";
100   llvm::interleaveComma(
101       operands, p, [&](const Value &v) { p << v << " : " << v.getType(); });
102   p << ") ";
103 }
104 
105 /// Print data variables corresponding to a data-sharing clause `name`
106 static void printDataVars(OpAsmPrinter &p, OperandRange operands,
107                           StringRef name) {
108   if (operands.size()) {
109     p << name;
110     printOperandAndTypeList(p, operands);
111   }
112 }
113 
114 //===----------------------------------------------------------------------===//
115 // Parser and printer for Allocate Clause
116 //===----------------------------------------------------------------------===//
117 
118 /// Parse an allocate clause with allocators and a list of operands with types.
119 ///
120 /// allocate ::= `allocate` `(` allocate-operand-list `)`
121 /// allocate-operand-list :: = allocate-operand |
122 ///                            allocator-operand `,` allocate-operand-list
123 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
124 /// ssa-id-and-type ::= ssa-id `:` type
125 static ParseResult parseAllocateAndAllocator(
126     OpAsmParser &parser,
127     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
128     SmallVectorImpl<Type> &typesAllocate,
129     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
130     SmallVectorImpl<Type> &typesAllocator) {
131 
132   return parser.parseCommaSeparatedList(
133       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
134         OpAsmParser::OperandType operand;
135         Type type;
136         if (parser.parseOperand(operand) || parser.parseColonType(type))
137           return failure();
138         operandsAllocator.push_back(operand);
139         typesAllocator.push_back(type);
140         if (parser.parseArrow())
141           return failure();
142         if (parser.parseOperand(operand) || parser.parseColonType(type))
143           return failure();
144 
145         operandsAllocate.push_back(operand);
146         typesAllocate.push_back(type);
147         return success();
148       });
149 }
150 
151 /// Print allocate clause
152 static void printAllocateAndAllocator(OpAsmPrinter &p,
153                                       OperandRange varsAllocate,
154                                       OperandRange varsAllocator) {
155   p << "allocate(";
156   for (unsigned i = 0; i < varsAllocate.size(); ++i) {
157     std::string separator = i == varsAllocate.size() - 1 ? ") " : ", ";
158     p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
159     p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
160   }
161 }
162 
163 static LogicalResult verifyParallelOp(ParallelOp op) {
164   if (op.allocate_vars().size() != op.allocators_vars().size())
165     return op.emitError(
166         "expected equal sizes for allocate and allocator variables");
167   return success();
168 }
169 
170 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
171   p << " ";
172   if (auto ifCond = op.if_expr_var())
173     p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
174 
175   if (auto threads = op.num_threads_var())
176     p << "num_threads(" << threads << " : " << threads.getType() << ") ";
177 
178   printDataVars(p, op.private_vars(), "private");
179   printDataVars(p, op.firstprivate_vars(), "firstprivate");
180   printDataVars(p, op.shared_vars(), "shared");
181   printDataVars(p, op.copyin_vars(), "copyin");
182 
183   if (!op.allocate_vars().empty())
184     printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars());
185 
186   if (auto def = op.default_val())
187     p << "default(" << def->drop_front(3) << ") ";
188 
189   if (auto bind = op.proc_bind_val())
190     p << "proc_bind(" << bind << ") ";
191 
192   p.printRegion(op.getRegion());
193 }
194 
195 //===----------------------------------------------------------------------===//
196 // Parser and printer for Linear Clause
197 //===----------------------------------------------------------------------===//
198 
199 /// linear ::= `linear` `(` linear-list `)`
200 /// linear-list := linear-val | linear-val linear-list
201 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
202 static ParseResult
203 parseLinearClause(OpAsmParser &parser,
204                   SmallVectorImpl<OpAsmParser::OperandType> &vars,
205                   SmallVectorImpl<Type> &types,
206                   SmallVectorImpl<OpAsmParser::OperandType> &stepVars) {
207   if (parser.parseLParen())
208     return failure();
209 
210   do {
211     OpAsmParser::OperandType var;
212     Type type;
213     OpAsmParser::OperandType stepVar;
214     if (parser.parseOperand(var) || parser.parseEqual() ||
215         parser.parseOperand(stepVar) || parser.parseColonType(type))
216       return failure();
217 
218     vars.push_back(var);
219     types.push_back(type);
220     stepVars.push_back(stepVar);
221   } while (succeeded(parser.parseOptionalComma()));
222 
223   if (parser.parseRParen())
224     return failure();
225 
226   return success();
227 }
228 
229 /// Print Linear Clause
230 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars,
231                               OperandRange linearStepVars) {
232   size_t linearVarsSize = linearVars.size();
233   p << "linear(";
234   for (unsigned i = 0; i < linearVarsSize; ++i) {
235     std::string separator = i == linearVarsSize - 1 ? ") " : ", ";
236     p << linearVars[i];
237     if (linearStepVars.size() > i)
238       p << " = " << linearStepVars[i];
239     p << " : " << linearVars[i].getType() << separator;
240   }
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // Parser and printer for Schedule Clause
245 //===----------------------------------------------------------------------===//
246 
247 static ParseResult
248 verifyScheduleModifiers(OpAsmParser &parser,
249                         SmallVectorImpl<SmallString<12>> &modifiers) {
250   if (modifiers.size() > 2)
251     return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
252   for (auto mod : modifiers) {
253     // Translate the string. If it has no value, then it was not a valid
254     // modifier!
255     auto symbol = symbolizeScheduleModifier(mod);
256     if (!symbol.hasValue())
257       return parser.emitError(parser.getNameLoc())
258              << " unknown modifier type: " << mod;
259   }
260 
261   // If we have one modifier that is "simd", then stick a "none" modiifer in
262   // index 0.
263   if (modifiers.size() == 1) {
264     if (symbolizeScheduleModifier(modifiers[0]) ==
265         mlir::omp::ScheduleModifier::simd) {
266       modifiers.push_back(modifiers[0]);
267       modifiers[0] =
268           stringifyScheduleModifier(mlir::omp::ScheduleModifier::none);
269     }
270   } else if (modifiers.size() == 2) {
271     // If there are two modifier:
272     // First modifier should not be simd, second one should be simd
273     if (symbolizeScheduleModifier(modifiers[0]) ==
274             mlir::omp::ScheduleModifier::simd ||
275         symbolizeScheduleModifier(modifiers[1]) !=
276             mlir::omp::ScheduleModifier::simd)
277       return parser.emitError(parser.getNameLoc())
278              << " incorrect modifier order";
279   }
280   return success();
281 }
282 
283 /// schedule ::= `schedule` `(` sched-list `)`
284 /// sched-list ::= sched-val | sched-val sched-list |
285 ///                sched-val `,` sched-modifier
286 /// sched-val ::= sched-with-chunk | sched-wo-chunk
287 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
288 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
289 /// sched-wo-chunk ::=  `auto` | `runtime`
290 /// sched-modifier ::=  sched-mod-val | sched-mod-val `,` sched-mod-val
291 /// sched-mod-val ::=  `monotonic` | `nonmonotonic` | `simd` | `none`
292 static ParseResult
293 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
294                     SmallVectorImpl<SmallString<12>> &modifiers,
295                     Optional<OpAsmParser::OperandType> &chunkSize) {
296   if (parser.parseLParen())
297     return failure();
298 
299   StringRef keyword;
300   if (parser.parseKeyword(&keyword))
301     return failure();
302 
303   schedule = keyword;
304   if (keyword == "static" || keyword == "dynamic" || keyword == "guided") {
305     if (succeeded(parser.parseOptionalEqual())) {
306       chunkSize = OpAsmParser::OperandType{};
307       if (parser.parseOperand(*chunkSize))
308         return failure();
309     } else {
310       chunkSize = llvm::NoneType::None;
311     }
312   } else if (keyword == "auto" || keyword == "runtime") {
313     chunkSize = llvm::NoneType::None;
314   } else {
315     return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
316   }
317 
318   // If there is a comma, we have one or more modifiers..
319   while (succeeded(parser.parseOptionalComma())) {
320     StringRef mod;
321     if (parser.parseKeyword(&mod))
322       return failure();
323     modifiers.push_back(mod);
324   }
325 
326   if (parser.parseRParen())
327     return failure();
328 
329   if (verifyScheduleModifiers(parser, modifiers))
330     return failure();
331 
332   return success();
333 }
334 
335 /// Print schedule clause
336 static void printScheduleClause(OpAsmPrinter &p, StringRef &sched,
337                                 llvm::Optional<StringRef> modifier, bool simd,
338                                 Value scheduleChunkVar) {
339   std::string schedLower = sched.lower();
340   p << "schedule(" << schedLower;
341   if (scheduleChunkVar)
342     p << " = " << scheduleChunkVar;
343   if (modifier && modifier.hasValue())
344     p << ", " << modifier;
345   if (simd)
346     p << ", simd";
347   p << ") ";
348 }
349 
350 //===----------------------------------------------------------------------===//
351 // Parser, printer and verifier for ReductionVarList
352 //===----------------------------------------------------------------------===//
353 
354 /// reduction ::= `reduction` `(` reduction-entry-list `)`
355 /// reduction-entry-list ::= reduction-entry
356 ///                        | reduction-entry-list `,` reduction-entry
357 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
358 static ParseResult
359 parseReductionVarList(OpAsmParser &parser,
360                       SmallVectorImpl<SymbolRefAttr> &symbols,
361                       SmallVectorImpl<OpAsmParser::OperandType> &operands,
362                       SmallVectorImpl<Type> &types) {
363   if (failed(parser.parseLParen()))
364     return failure();
365 
366   do {
367     if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() ||
368         parser.parseOperand(operands.emplace_back()) ||
369         parser.parseColonType(types.emplace_back()))
370       return failure();
371   } while (succeeded(parser.parseOptionalComma()));
372   return parser.parseRParen();
373 }
374 
375 /// Print Reduction clause
376 static void printReductionVarList(OpAsmPrinter &p,
377                                   Optional<ArrayAttr> reductions,
378                                   OperandRange reduction_vars) {
379   p << "reduction(";
380   for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
381     if (i != 0)
382       p << ", ";
383     p << (*reductions)[i] << " -> " << reduction_vars[i] << " : "
384       << reduction_vars[i].getType();
385   }
386   p << ") ";
387 }
388 
389 /// Verifies Reduction Clause
390 static LogicalResult verifyReductionVarList(Operation *op,
391                                             Optional<ArrayAttr> reductions,
392                                             OperandRange reduction_vars) {
393   if (reduction_vars.size() != 0) {
394     if (!reductions || reductions->size() != reduction_vars.size())
395       return op->emitOpError()
396              << "expected as many reduction symbol references "
397                 "as reduction variables";
398   } else {
399     if (reductions)
400       return op->emitOpError() << "unexpected reduction symbol references";
401     return success();
402   }
403 
404   DenseSet<Value> accumulators;
405   for (auto args : llvm::zip(reduction_vars, *reductions)) {
406     Value accum = std::get<0>(args);
407 
408     if (!accumulators.insert(accum).second)
409       return op->emitOpError() << "accumulator variable used more than once";
410 
411     Type varType = accum.getType().cast<PointerLikeType>();
412     auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
413     auto decl =
414         SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
415     if (!decl)
416       return op->emitOpError() << "expected symbol reference " << symbolRef
417                                << " to point to a reduction declaration";
418 
419     if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
420       return op->emitOpError()
421              << "expected accumulator (" << varType
422              << ") to be the same type as reduction declaration ("
423              << decl.getAccumulatorType() << ")";
424   }
425 
426   return success();
427 }
428 
429 //===----------------------------------------------------------------------===//
430 // Parser, printer and verifier for Synchronization Hint (2.17.12)
431 //===----------------------------------------------------------------------===//
432 
433 /// Parses a Synchronization Hint clause. The value of hint is an integer
434 /// which is a combination of different hints from `omp_sync_hint_t`.
435 ///
436 /// hint-clause = `hint` `(` hint-value `)`
437 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
438                                             IntegerAttr &hintAttr,
439                                             bool parseKeyword = true) {
440   if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) {
441     hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
442     return success();
443   }
444 
445   if (failed(parser.parseLParen()))
446     return failure();
447   StringRef hintKeyword;
448   int64_t hint = 0;
449   do {
450     if (failed(parser.parseKeyword(&hintKeyword)))
451       return failure();
452     if (hintKeyword == "uncontended")
453       hint |= 1;
454     else if (hintKeyword == "contended")
455       hint |= 2;
456     else if (hintKeyword == "nonspeculative")
457       hint |= 4;
458     else if (hintKeyword == "speculative")
459       hint |= 8;
460     else
461       return parser.emitError(parser.getCurrentLocation())
462              << hintKeyword << " is not a valid hint";
463   } while (succeeded(parser.parseOptionalComma()));
464   if (failed(parser.parseRParen()))
465     return failure();
466   hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
467   return success();
468 }
469 
470 /// Prints a Synchronization Hint clause
471 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
472                                      IntegerAttr hintAttr) {
473   int64_t hint = hintAttr.getInt();
474 
475   if (hint == 0)
476     return;
477 
478   // Helper function to get n-th bit from the right end of `value`
479   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
480 
481   bool uncontended = bitn(hint, 0);
482   bool contended = bitn(hint, 1);
483   bool nonspeculative = bitn(hint, 2);
484   bool speculative = bitn(hint, 3);
485 
486   SmallVector<StringRef> hints;
487   if (uncontended)
488     hints.push_back("uncontended");
489   if (contended)
490     hints.push_back("contended");
491   if (nonspeculative)
492     hints.push_back("nonspeculative");
493   if (speculative)
494     hints.push_back("speculative");
495 
496   p << "hint(";
497   llvm::interleaveComma(hints, p);
498   p << ") ";
499 }
500 
501 /// Verifies a synchronization hint clause
502 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
503 
504   // Helper function to get n-th bit from the right end of `value`
505   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
506 
507   bool uncontended = bitn(hint, 0);
508   bool contended = bitn(hint, 1);
509   bool nonspeculative = bitn(hint, 2);
510   bool speculative = bitn(hint, 3);
511 
512   if (uncontended && contended)
513     return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
514                                 "omp_sync_hint_contended cannot be combined";
515   if (nonspeculative && speculative)
516     return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
517                                 "omp_sync_hint_speculative cannot be combined.";
518   return success();
519 }
520 
/// Identifiers for the clauses handled by the clause-list parser. The
/// enumerators are used as indices into per-clause tables, so `COUNT` has to
/// remain the last entry.
enum ClauseType {
  ifClause,
  numThreadsClause,
  privateClause,
  firstprivateClause,
  lastprivateClause,
  sharedClause,
  copyinClause,
  allocateClause,
  defaultClause,
  procBindClause,
  reductionClause,
  nowaitClause,
  linearClause,
  scheduleClause,
  collapseClause,
  orderClause,
  orderedClause,
  memoryOrderClause,
  hintClause,
  // Number of clause kinds; not a clause itself.
  COUNT
};
543 
544 //===----------------------------------------------------------------------===//
545 // Parser for Clause List
546 //===----------------------------------------------------------------------===//
547 
548 /// Parse a list of clauses. The clauses can appear in any order, but their
549 /// operand segment indices are in the same order that they are passed in the
550 /// `clauses` list. The operand segments are added over the prevSegments
551 
552 /// clause-list ::= clause clause-list | empty
553 /// clause ::= if | num-threads | private | firstprivate | lastprivate |
554 ///            shared | copyin | allocate | default | proc-bind | reduction |
555 ///            nowait | linear | schedule | collapse | order | ordered |
556 ///            inclusive
557 /// if ::= `if` `(` ssa-id-and-type `)`
558 /// num-threads ::= `num_threads` `(` ssa-id-and-type `)`
559 /// private ::= `private` operand-and-type-list
560 /// firstprivate ::= `firstprivate` operand-and-type-list
561 /// lastprivate ::= `lastprivate` operand-and-type-list
562 /// shared ::= `shared` operand-and-type-list
563 /// copyin ::= `copyin` operand-and-type-list
564 /// allocate ::= `allocate` `(` allocate-operand-list `)`
/// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`)
///             `)`
566 /// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
567 /// reduction ::= `reduction` `(` reduction-entry-list `)`
568 /// nowait ::= `nowait`
569 /// linear ::= `linear` `(` linear-list `)`
570 /// schedule ::= `schedule` `(` sched-list `)`
571 /// collapse ::= `collapse` `(` ssa-id-and-type `)`
572 /// order ::= `order` `(` `concurrent` `)`
573 /// ordered ::= `ordered` `(` ssa-id-and-type `)`
574 /// inclusive ::= `inclusive`
575 ///
/// Note that each clause can only appear once in the clause-list.
577 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
578                                 SmallVectorImpl<ClauseType> &clauses,
579                                 SmallVectorImpl<int> &segments) {
580 
581   // Check done[clause] to see if it has been parsed already
582   llvm::BitVector done(ClauseType::COUNT, false);
583 
584   // See pos[clause] to get position of clause in operand segments
585   SmallVector<int> pos(ClauseType::COUNT, -1);
586 
587   // Stores the last parsed clause keyword
588   StringRef clauseKeyword;
589   StringRef opName = result.name.getStringRef();
590 
591   // Containers for storing operands, types and attributes for various clauses
592   std::pair<OpAsmParser::OperandType, Type> ifCond;
593   std::pair<OpAsmParser::OperandType, Type> numThreads;
594 
595   SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates,
596       shareds, copyins;
597   SmallVector<Type> privateTypes, firstprivateTypes, lastprivateTypes,
598       sharedTypes, copyinTypes;
599 
600   SmallVector<OpAsmParser::OperandType> allocates, allocators;
601   SmallVector<Type> allocateTypes, allocatorTypes;
602 
603   SmallVector<SymbolRefAttr> reductionSymbols;
604   SmallVector<OpAsmParser::OperandType> reductionVars;
605   SmallVector<Type> reductionVarTypes;
606 
607   SmallVector<OpAsmParser::OperandType> linears;
608   SmallVector<Type> linearTypes;
609   SmallVector<OpAsmParser::OperandType> linearSteps;
610 
611   SmallString<8> schedule;
612   SmallVector<SmallString<12>> modifiers;
613   Optional<OpAsmParser::OperandType> scheduleChunkSize;
614 
615   // Compute the position of clauses in operand segments
616   int currPos = 0;
617   for (ClauseType clause : clauses) {
618 
619     // Skip the following clauses - they do not take any position in operand
620     // segments
621     if (clause == defaultClause || clause == procBindClause ||
622         clause == nowaitClause || clause == collapseClause ||
623         clause == orderClause || clause == orderedClause)
624       continue;
625 
626     pos[clause] = currPos++;
627 
628     // For the following clauses, two positions are reserved in the operand
629     // segments
630     if (clause == allocateClause || clause == linearClause)
631       currPos++;
632   }
633 
634   SmallVector<int> clauseSegments(currPos);
635 
636   // Helper function to check if a clause is allowed/repeated or not
637   auto checkAllowed = [&](ClauseType clause,
638                           bool allowRepeat = false) -> ParseResult {
639     if (!llvm::is_contained(clauses, clause))
640       return parser.emitError(parser.getCurrentLocation())
641              << clauseKeyword << " is not a valid clause for the " << opName
642              << " operation";
643     if (done[clause] && !allowRepeat)
644       return parser.emitError(parser.getCurrentLocation())
645              << "at most one " << clauseKeyword << " clause can appear on the "
646              << opName << " operation";
647     done[clause] = true;
648     return success();
649   };
650 
651   while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) {
652     if (clauseKeyword == "if") {
653       if (checkAllowed(ifClause) || parser.parseLParen() ||
654           parser.parseOperand(ifCond.first) ||
655           parser.parseColonType(ifCond.second) || parser.parseRParen())
656         return failure();
657       clauseSegments[pos[ifClause]] = 1;
658     } else if (clauseKeyword == "num_threads") {
659       if (checkAllowed(numThreadsClause) || parser.parseLParen() ||
660           parser.parseOperand(numThreads.first) ||
661           parser.parseColonType(numThreads.second) || parser.parseRParen())
662         return failure();
663       clauseSegments[pos[numThreadsClause]] = 1;
664     } else if (clauseKeyword == "private") {
665       if (checkAllowed(privateClause) ||
666           parseOperandAndTypeList(parser, privates, privateTypes))
667         return failure();
668       clauseSegments[pos[privateClause]] = privates.size();
669     } else if (clauseKeyword == "firstprivate") {
670       if (checkAllowed(firstprivateClause) ||
671           parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
672         return failure();
673       clauseSegments[pos[firstprivateClause]] = firstprivates.size();
674     } else if (clauseKeyword == "lastprivate") {
675       if (checkAllowed(lastprivateClause) ||
676           parseOperandAndTypeList(parser, lastprivates, lastprivateTypes))
677         return failure();
678       clauseSegments[pos[lastprivateClause]] = lastprivates.size();
679     } else if (clauseKeyword == "shared") {
680       if (checkAllowed(sharedClause) ||
681           parseOperandAndTypeList(parser, shareds, sharedTypes))
682         return failure();
683       clauseSegments[pos[sharedClause]] = shareds.size();
684     } else if (clauseKeyword == "copyin") {
685       if (checkAllowed(copyinClause) ||
686           parseOperandAndTypeList(parser, copyins, copyinTypes))
687         return failure();
688       clauseSegments[pos[copyinClause]] = copyins.size();
689     } else if (clauseKeyword == "allocate") {
690       if (checkAllowed(allocateClause) ||
691           parseAllocateAndAllocator(parser, allocates, allocateTypes,
692                                     allocators, allocatorTypes))
693         return failure();
694       clauseSegments[pos[allocateClause]] = allocates.size();
695       clauseSegments[pos[allocateClause] + 1] = allocators.size();
696     } else if (clauseKeyword == "default") {
697       StringRef defval;
698       if (checkAllowed(defaultClause) || parser.parseLParen() ||
699           parser.parseKeyword(&defval) || parser.parseRParen())
700         return failure();
701       // The def prefix is required for the attribute as "private" is a keyword
702       // in C++.
703       auto attr = parser.getBuilder().getStringAttr("def" + defval);
704       result.addAttribute("default_val", attr);
705     } else if (clauseKeyword == "proc_bind") {
706       StringRef bind;
707       if (checkAllowed(procBindClause) || parser.parseLParen() ||
708           parser.parseKeyword(&bind) || parser.parseRParen())
709         return failure();
710       auto attr = parser.getBuilder().getStringAttr(bind);
711       result.addAttribute("proc_bind_val", attr);
712     } else if (clauseKeyword == "reduction") {
713       if (checkAllowed(reductionClause) ||
714           parseReductionVarList(parser, reductionSymbols, reductionVars,
715                                 reductionVarTypes))
716         return failure();
717       clauseSegments[pos[reductionClause]] = reductionVars.size();
718     } else if (clauseKeyword == "nowait") {
719       if (checkAllowed(nowaitClause))
720         return failure();
721       auto attr = UnitAttr::get(parser.getBuilder().getContext());
722       result.addAttribute("nowait", attr);
723     } else if (clauseKeyword == "linear") {
724       if (checkAllowed(linearClause) ||
725           parseLinearClause(parser, linears, linearTypes, linearSteps))
726         return failure();
727       clauseSegments[pos[linearClause]] = linears.size();
728       clauseSegments[pos[linearClause] + 1] = linearSteps.size();
729     } else if (clauseKeyword == "schedule") {
730       if (checkAllowed(scheduleClause) ||
731           parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize))
732         return failure();
733       if (scheduleChunkSize) {
734         clauseSegments[pos[scheduleClause]] = 1;
735       }
736     } else if (clauseKeyword == "collapse") {
737       auto type = parser.getBuilder().getI64Type();
738       mlir::IntegerAttr attr;
739       if (checkAllowed(collapseClause) || parser.parseLParen() ||
740           parser.parseAttribute(attr, type) || parser.parseRParen())
741         return failure();
742       result.addAttribute("collapse_val", attr);
743     } else if (clauseKeyword == "ordered") {
744       mlir::IntegerAttr attr;
745       if (checkAllowed(orderedClause))
746         return failure();
747       if (succeeded(parser.parseOptionalLParen())) {
748         auto type = parser.getBuilder().getI64Type();
749         if (parser.parseAttribute(attr, type) || parser.parseRParen())
750           return failure();
751       } else {
752         // Use 0 to represent no ordered parameter was specified
753         attr = parser.getBuilder().getI64IntegerAttr(0);
754       }
755       result.addAttribute("ordered_val", attr);
756     } else if (clauseKeyword == "order") {
757       StringRef order;
758       if (checkAllowed(orderClause) || parser.parseLParen() ||
759           parser.parseKeyword(&order) || parser.parseRParen())
760         return failure();
761       auto attr = parser.getBuilder().getStringAttr(order);
762       result.addAttribute("order_val", attr);
763     } else if (clauseKeyword == "memory_order") {
764       StringRef memoryOrder;
765       if (checkAllowed(memoryOrderClause) || parser.parseLParen() ||
766           parser.parseKeyword(&memoryOrder) || parser.parseRParen())
767         return failure();
768       result.addAttribute("memory_order",
769                           parser.getBuilder().getStringAttr(memoryOrder));
770     } else if (clauseKeyword == "hint") {
771       IntegerAttr hint;
772       if (checkAllowed(hintClause) ||
773           parseSynchronizationHint(parser, hint, false))
774         return failure();
775       result.addAttribute("hint", hint);
776     } else {
777       return parser.emitError(parser.getNameLoc())
778              << clauseKeyword << " is not a valid clause";
779     }
780   }
781 
782   // Add if parameter.
783   if (done[ifClause] && clauseSegments[pos[ifClause]] &&
784       failed(
785           parser.resolveOperand(ifCond.first, ifCond.second, result.operands)))
786     return failure();
787 
788   // Add num_threads parameter.
789   if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] &&
790       failed(parser.resolveOperand(numThreads.first, numThreads.second,
791                                    result.operands)))
792     return failure();
793 
794   // Add private parameters.
795   if (done[privateClause] && clauseSegments[pos[privateClause]] &&
796       failed(parser.resolveOperands(privates, privateTypes,
797                                     privates[0].location, result.operands)))
798     return failure();
799 
800   // Add firstprivate parameters.
801   if (done[firstprivateClause] && clauseSegments[pos[firstprivateClause]] &&
802       failed(parser.resolveOperands(firstprivates, firstprivateTypes,
803                                     firstprivates[0].location,
804                                     result.operands)))
805     return failure();
806 
807   // Add lastprivate parameters.
808   if (done[lastprivateClause] && clauseSegments[pos[lastprivateClause]] &&
809       failed(parser.resolveOperands(lastprivates, lastprivateTypes,
810                                     lastprivates[0].location, result.operands)))
811     return failure();
812 
813   // Add shared parameters.
814   if (done[sharedClause] && clauseSegments[pos[sharedClause]] &&
815       failed(parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
816                                     result.operands)))
817     return failure();
818 
819   // Add copyin parameters.
820   if (done[copyinClause] && clauseSegments[pos[copyinClause]] &&
821       failed(parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
822                                     result.operands)))
823     return failure();
824 
825   // Add allocate parameters.
826   if (done[allocateClause] && clauseSegments[pos[allocateClause]] &&
827       failed(parser.resolveOperands(allocates, allocateTypes,
828                                     allocates[0].location, result.operands)))
829     return failure();
830 
831   // Add allocator parameters.
832   if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] &&
833       failed(parser.resolveOperands(allocators, allocatorTypes,
834                                     allocators[0].location, result.operands)))
835     return failure();
836 
837   // Add reduction parameters and symbols
838   if (done[reductionClause] && clauseSegments[pos[reductionClause]]) {
839     if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
840                                       parser.getNameLoc(), result.operands)))
841       return failure();
842 
843     SmallVector<Attribute> reductions(reductionSymbols.begin(),
844                                       reductionSymbols.end());
845     result.addAttribute("reductions",
846                         parser.getBuilder().getArrayAttr(reductions));
847   }
848 
849   // Add linear parameters
850   if (done[linearClause] && clauseSegments[pos[linearClause]]) {
851     auto linearStepType = parser.getBuilder().getI32Type();
852     SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
853     if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location,
854                                       result.operands)) ||
855         failed(parser.resolveOperands(linearSteps, linearStepTypes,
856                                       linearSteps[0].location,
857                                       result.operands)))
858       return failure();
859   }
860 
861   // Add schedule parameters
862   if (done[scheduleClause] && !schedule.empty()) {
863     schedule[0] = llvm::toUpper(schedule[0]);
864     auto attr = parser.getBuilder().getStringAttr(schedule);
865     result.addAttribute("schedule_val", attr);
866     if (modifiers.size() > 0) {
867       auto mod = parser.getBuilder().getStringAttr(modifiers[0]);
868       result.addAttribute("schedule_modifier", mod);
869       // Only SIMD attribute is allowed here!
870       if (modifiers.size() > 1) {
871         assert(symbolizeScheduleModifier(modifiers[1]) ==
872                mlir::omp::ScheduleModifier::simd);
873         auto attr = UnitAttr::get(parser.getBuilder().getContext());
874         result.addAttribute("simd_modifier", attr);
875       }
876     }
877     if (scheduleChunkSize) {
878       auto chunkSizeType = parser.getBuilder().getI32Type();
879       parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands);
880     }
881   }
882 
883   segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end());
884 
885   return success();
886 }
887 
888 /// Parses a parallel operation.
889 ///
890 /// operation ::= `omp.parallel` clause-list
891 /// clause-list ::= clause | clause clause-list
892 /// clause ::= if | num-threads | private | firstprivate | shared | copyin |
893 ///            allocate | default | proc-bind
894 ///
895 static ParseResult parseParallelOp(OpAsmParser &parser,
896                                    OperationState &result) {
897   SmallVector<ClauseType> clauses = {
898       ifClause,           numThreadsClause, privateClause,
899       firstprivateClause, sharedClause,     copyinClause,
900       allocateClause,     defaultClause,    procBindClause};
901 
902   SmallVector<int> segments;
903 
904   if (failed(parseClauses(parser, result, clauses, segments)))
905     return failure();
906 
907   result.addAttribute("operand_segment_sizes",
908                       parser.getBuilder().getI32VectorAttr(segments));
909 
910   Region *body = result.addRegion();
911   SmallVector<OpAsmParser::OperandType> regionArgs;
912   SmallVector<Type> regionArgTypes;
913   if (parser.parseRegion(*body, regionArgs, regionArgTypes))
914     return failure();
915   return success();
916 }
917 
918 //===----------------------------------------------------------------------===//
919 // Parser, printer and verifier for SectionsOp
920 //===----------------------------------------------------------------------===//
921 
922 /// Parses an OpenMP Sections operation
923 ///
924 /// sections ::= `omp.sections` clause-list
925 /// clause-list ::= clause clause-list | empty
926 /// clause ::= private | firstprivate | lastprivate | reduction | allocate |
927 ///            nowait
928 static ParseResult parseSectionsOp(OpAsmParser &parser,
929                                    OperationState &result) {
930 
931   SmallVector<ClauseType> clauses = {privateClause,     firstprivateClause,
932                                      lastprivateClause, reductionClause,
933                                      allocateClause,    nowaitClause};
934 
935   SmallVector<int> segments;
936 
937   if (failed(parseClauses(parser, result, clauses, segments)))
938     return failure();
939 
940   result.addAttribute("operand_segment_sizes",
941                       parser.getBuilder().getI32VectorAttr(segments));
942 
943   // Now parse the body.
944   Region *body = result.addRegion();
945   if (parser.parseRegion(*body))
946     return failure();
947   return success();
948 }
949 
950 static void printSectionsOp(OpAsmPrinter &p, SectionsOp op) {
951   p << " ";
952   printDataVars(p, op.private_vars(), "private");
953   printDataVars(p, op.firstprivate_vars(), "firstprivate");
954   printDataVars(p, op.lastprivate_vars(), "lastprivate");
955 
956   if (!op.reduction_vars().empty())
957     printReductionVarList(p, op.reductions(), op.reduction_vars());
958 
959   if (!op.allocate_vars().empty())
960     printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars());
961 
962   if (op.nowait())
963     p << "nowait ";
964 
965   p.printRegion(op.region());
966 }
967 
968 static LogicalResult verifySectionsOp(SectionsOp op) {
969 
970   // A list item may not appear in more than one clause on the same directive,
971   // except that it may be specified in both firstprivate and lastprivate
972   // clauses.
973   for (auto var : op.private_vars()) {
974     if (llvm::is_contained(op.firstprivate_vars(), var))
975       return op.emitOpError()
976              << "operand used in both private and firstprivate clauses";
977     if (llvm::is_contained(op.lastprivate_vars(), var))
978       return op.emitOpError()
979              << "operand used in both private and lastprivate clauses";
980   }
981 
982   if (op.allocate_vars().size() != op.allocators_vars().size())
983     return op.emitError(
984         "expected equal sizes for allocate and allocator variables");
985 
986   for (auto &inst : *op.region().begin()) {
987     if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst)))
988       op.emitOpError()
989           << "expected omp.section op or terminator op inside region";
990   }
991 
992   return verifyReductionVarList(op, op.reductions(), op.reduction_vars());
993 }
994 
995 /// Parses an OpenMP Workshare Loop operation
996 ///
997 /// wsloop ::= `omp.wsloop` loop-control clause-list
998 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
999 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
1000 /// steps := `step` `(`ssa-id-list`)`
1001 /// clause-list ::= clause clause-list | empty
1002 /// clause ::= private | firstprivate | lastprivate | linear | schedule |
1003 //             collapse | nowait | ordered | order | reduction
1004 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
1005 
1006   // Parse an opening `(` followed by induction variables followed by `)`
1007   SmallVector<OpAsmParser::OperandType> ivs;
1008   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
1009                                      OpAsmParser::Delimiter::Paren))
1010     return failure();
1011 
1012   int numIVs = static_cast<int>(ivs.size());
1013   Type loopVarType;
1014   if (parser.parseColonType(loopVarType))
1015     return failure();
1016 
1017   // Parse loop bounds.
1018   SmallVector<OpAsmParser::OperandType> lower;
1019   if (parser.parseEqual() ||
1020       parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
1021       parser.resolveOperands(lower, loopVarType, result.operands))
1022     return failure();
1023 
1024   SmallVector<OpAsmParser::OperandType> upper;
1025   if (parser.parseKeyword("to") ||
1026       parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
1027       parser.resolveOperands(upper, loopVarType, result.operands))
1028     return failure();
1029 
1030   if (succeeded(parser.parseOptionalKeyword("inclusive"))) {
1031     auto attr = UnitAttr::get(parser.getBuilder().getContext());
1032     result.addAttribute("inclusive", attr);
1033   }
1034 
1035   // Parse step values.
1036   SmallVector<OpAsmParser::OperandType> steps;
1037   if (parser.parseKeyword("step") ||
1038       parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
1039       parser.resolveOperands(steps, loopVarType, result.operands))
1040     return failure();
1041 
1042   SmallVector<ClauseType> clauses = {
1043       privateClause,   firstprivateClause, lastprivateClause, linearClause,
1044       reductionClause, collapseClause,     orderClause,       orderedClause,
1045       nowaitClause,    scheduleClause};
1046   SmallVector<int> segments{numIVs, numIVs, numIVs};
1047   if (failed(parseClauses(parser, result, clauses, segments)))
1048     return failure();
1049 
1050   result.addAttribute("operand_segment_sizes",
1051                       parser.getBuilder().getI32VectorAttr(segments));
1052 
1053   // Now parse the body.
1054   Region *body = result.addRegion();
1055   SmallVector<Type> ivTypes(numIVs, loopVarType);
1056   SmallVector<OpAsmParser::OperandType> blockArgs(ivs);
1057   if (parser.parseRegion(*body, blockArgs, ivTypes))
1058     return failure();
1059   return success();
1060 }
1061 
1062 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
1063   auto args = op.getRegion().front().getArguments();
1064   p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound()
1065     << ") to (" << op.upperBound() << ") ";
1066   if (op.inclusive()) {
1067     p << "inclusive ";
1068   }
1069   p << "step (" << op.step() << ") ";
1070 
1071   printDataVars(p, op.private_vars(), "private");
1072   printDataVars(p, op.firstprivate_vars(), "firstprivate");
1073   printDataVars(p, op.lastprivate_vars(), "lastprivate");
1074 
1075   if (op.linear_vars().size())
1076     printLinearClause(p, op.linear_vars(), op.linear_step_vars());
1077 
1078   if (auto sched = op.schedule_val())
1079     printScheduleClause(p, sched.getValue(), op.schedule_modifier(),
1080                         op.simd_modifier(), op.schedule_chunk_var());
1081 
1082   if (auto collapse = op.collapse_val())
1083     p << "collapse(" << collapse << ") ";
1084 
1085   if (op.nowait())
1086     p << "nowait ";
1087 
1088   if (auto ordered = op.ordered_val())
1089     p << "ordered(" << ordered << ") ";
1090 
1091   if (auto order = op.order_val())
1092     p << "order(" << order << ") ";
1093 
1094   if (!op.reduction_vars().empty())
1095     printReductionVarList(p, op.reductions(), op.reduction_vars());
1096 
1097   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
1098 }
1099 
1100 //===----------------------------------------------------------------------===//
1101 // ReductionOp
1102 //===----------------------------------------------------------------------===//
1103 
1104 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
1105                                               Region &region) {
1106   if (parser.parseOptionalKeyword("atomic"))
1107     return success();
1108   return parser.parseRegion(region);
1109 }
1110 
1111 static void printAtomicReductionRegion(OpAsmPrinter &printer,
1112                                        ReductionDeclareOp op, Region &region) {
1113   if (region.empty())
1114     return;
1115   printer << "atomic ";
1116   printer.printRegion(region);
1117 }
1118 
1119 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) {
1120   if (op.initializerRegion().empty())
1121     return op.emitOpError() << "expects non-empty initializer region";
1122   Block &initializerEntryBlock = op.initializerRegion().front();
1123   if (initializerEntryBlock.getNumArguments() != 1 ||
1124       initializerEntryBlock.getArgument(0).getType() != op.type()) {
1125     return op.emitOpError() << "expects initializer region with one argument "
1126                                "of the reduction type";
1127   }
1128 
1129   for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) {
1130     if (yieldOp.results().size() != 1 ||
1131         yieldOp.results().getTypes()[0] != op.type())
1132       return op.emitOpError() << "expects initializer region to yield a value "
1133                                  "of the reduction type";
1134   }
1135 
1136   if (op.reductionRegion().empty())
1137     return op.emitOpError() << "expects non-empty reduction region";
1138   Block &reductionEntryBlock = op.reductionRegion().front();
1139   if (reductionEntryBlock.getNumArguments() != 2 ||
1140       reductionEntryBlock.getArgumentTypes()[0] !=
1141           reductionEntryBlock.getArgumentTypes()[1] ||
1142       reductionEntryBlock.getArgumentTypes()[0] != op.type())
1143     return op.emitOpError() << "expects reduction region with two arguments of "
1144                                "the reduction type";
1145   for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) {
1146     if (yieldOp.results().size() != 1 ||
1147         yieldOp.results().getTypes()[0] != op.type())
1148       return op.emitOpError() << "expects reduction region to yield a value "
1149                                  "of the reduction type";
1150   }
1151 
1152   if (op.atomicReductionRegion().empty())
1153     return success();
1154 
1155   Block &atomicReductionEntryBlock = op.atomicReductionRegion().front();
1156   if (atomicReductionEntryBlock.getNumArguments() != 2 ||
1157       atomicReductionEntryBlock.getArgumentTypes()[0] !=
1158           atomicReductionEntryBlock.getArgumentTypes()[1])
1159     return op.emitOpError() << "expects atomic reduction region with two "
1160                                "arguments of the same type";
1161   auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
1162                      .dyn_cast<PointerLikeType>();
1163   if (!ptrType || ptrType.getElementType() != op.type())
1164     return op.emitOpError() << "expects atomic reduction region arguments to "
1165                                "be accumulators containing the reduction type";
1166   return success();
1167 }
1168 
1169 static LogicalResult verifyReductionOp(ReductionOp op) {
1170   // TODO: generalize this to an op interface when there is more than one op
1171   // that supports reductions.
1172   auto container = op->getParentOfType<WsLoopOp>();
1173   for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
1174     if (container.reduction_vars()[i] == op.accumulator())
1175       return success();
1176 
1177   return op.emitOpError() << "the accumulator is not used by the parent";
1178 }
1179 
1180 //===----------------------------------------------------------------------===//
1181 // WsLoopOp
1182 //===----------------------------------------------------------------------===//
1183 
/// Convenience builder taking only the loop bounds, the steps and a list of
/// extra attributes.
///
/// Delegates to a more detailed `build` overload with an empty result type
/// range, empty ranges for every clause operand group, null values for every
/// clause attribute and for the schedule chunk, and `buildBody=false`. The
/// attributes in `attributes` are appended to `state` after that call.
void WsLoopOp::build(OpBuilder &builder, OperationState &state,
                     ValueRange lowerBound, ValueRange upperBound,
                     ValueRange step, ArrayRef<NamedAttribute> attributes) {
  build(builder, state, TypeRange(), lowerBound, upperBound, step,
        /*private_vars=*/ValueRange(),
        /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
        /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
        /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr,
        /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr,
        /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr,
        /*inclusive=*/nullptr, /*buildBody=*/false);
  // Caller-supplied attributes are added on top of what the delegate set.
  state.addAttributes(attributes);
}
1197 
1198 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes,
1199                      ValueRange operands, ArrayRef<NamedAttribute> attributes) {
1200   state.addOperands(operands);
1201   state.addAttributes(attributes);
1202   (void)state.addRegion();
1203   assert(resultTypes.empty() && "mismatched number of return types");
1204   state.addTypes(resultTypes);
1205 }
1206 
1207 void WsLoopOp::build(OpBuilder &builder, OperationState &result,
1208                      TypeRange typeRange, ValueRange lowerBounds,
1209                      ValueRange upperBounds, ValueRange steps,
1210                      ValueRange privateVars, ValueRange firstprivateVars,
1211                      ValueRange lastprivateVars, ValueRange linearVars,
1212                      ValueRange linearStepVars, ValueRange reductionVars,
1213                      StringAttr scheduleVal, Value scheduleChunkVar,
1214                      IntegerAttr collapseVal, UnitAttr nowait,
1215                      IntegerAttr orderedVal, StringAttr orderVal,
1216                      UnitAttr inclusive, bool buildBody) {
1217   result.addOperands(lowerBounds);
1218   result.addOperands(upperBounds);
1219   result.addOperands(steps);
1220   result.addOperands(privateVars);
1221   result.addOperands(firstprivateVars);
1222   result.addOperands(linearVars);
1223   result.addOperands(linearStepVars);
1224   if (scheduleChunkVar)
1225     result.addOperands(scheduleChunkVar);
1226 
1227   if (scheduleVal)
1228     result.addAttribute("schedule_val", scheduleVal);
1229   if (collapseVal)
1230     result.addAttribute("collapse_val", collapseVal);
1231   if (nowait)
1232     result.addAttribute("nowait", nowait);
1233   if (orderedVal)
1234     result.addAttribute("ordered_val", orderedVal);
1235   if (orderVal)
1236     result.addAttribute("order", orderVal);
1237   if (inclusive)
1238     result.addAttribute("inclusive", inclusive);
1239   result.addAttribute(
1240       WsLoopOp::getOperandSegmentSizeAttr(),
1241       builder.getI32VectorAttr(
1242           {static_cast<int32_t>(lowerBounds.size()),
1243            static_cast<int32_t>(upperBounds.size()),
1244            static_cast<int32_t>(steps.size()),
1245            static_cast<int32_t>(privateVars.size()),
1246            static_cast<int32_t>(firstprivateVars.size()),
1247            static_cast<int32_t>(lastprivateVars.size()),
1248            static_cast<int32_t>(linearVars.size()),
1249            static_cast<int32_t>(linearStepVars.size()),
1250            static_cast<int32_t>(reductionVars.size()),
1251            static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)}));
1252 
1253   Region *bodyRegion = result.addRegion();
1254   if (buildBody) {
1255     OpBuilder::InsertionGuard guard(builder);
1256     unsigned numIVs = steps.size();
1257     SmallVector<Type, 8> argTypes(numIVs, steps.getType().front());
1258     builder.createBlock(bodyRegion, {}, argTypes);
1259   }
1260 }
1261 
1262 static LogicalResult verifyWsLoopOp(WsLoopOp op) {
1263   return verifyReductionVarList(op, op.reductions(), op.reduction_vars());
1264 }
1265 
1266 //===----------------------------------------------------------------------===//
1267 // Verifier for critical construct (2.17.1)
1268 //===----------------------------------------------------------------------===//
1269 
1270 static LogicalResult verifyCriticalDeclareOp(CriticalDeclareOp op) {
1271   return verifySynchronizationHint(op, op.hint());
1272 }
1273 
1274 static LogicalResult verifyCriticalOp(CriticalOp op) {
1275 
1276   if (op.nameAttr()) {
1277     auto symbolRef = op.nameAttr().cast<SymbolRefAttr>();
1278     auto decl =
1279         SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef);
1280     if (!decl) {
1281       return op.emitOpError() << "expected symbol reference " << symbolRef
1282                               << " to point to a critical declaration";
1283     }
1284   }
1285 
1286   return success();
1287 }
1288 
1289 //===----------------------------------------------------------------------===//
1290 // Verifier for ordered construct
1291 //===----------------------------------------------------------------------===//
1292 
1293 static LogicalResult verifyOrderedOp(OrderedOp op) {
1294   auto container = op->getParentOfType<WsLoopOp>();
1295   if (!container || !container.ordered_valAttr() ||
1296       container.ordered_valAttr().getInt() == 0)
1297     return op.emitOpError() << "ordered depend directive must be closely "
1298                             << "nested inside a worksharing-loop with ordered "
1299                             << "clause with parameter present";
1300 
1301   if (container.ordered_valAttr().getInt() !=
1302       (int64_t)op.num_loops_val().getValue())
1303     return op.emitOpError() << "number of variables in depend clause does not "
1304                             << "match number of iteration variables in the "
1305                             << "doacross loop";
1306 
1307   return success();
1308 }
1309 
1310 static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) {
1311   // TODO: The code generation for ordered simd directive is not supported yet.
1312   if (op.simd())
1313     return failure();
1314 
1315   if (auto container = op->getParentOfType<WsLoopOp>()) {
1316     if (!container.ordered_valAttr() ||
1317         container.ordered_valAttr().getInt() != 0)
1318       return op.emitOpError() << "ordered region must be closely nested inside "
1319                               << "a worksharing-loop region with an ordered "
1320                               << "clause without parameter present";
1321   }
1322 
1323   return success();
1324 }
1325 
1326 //===----------------------------------------------------------------------===//
1327 // AtomicReadOp
1328 //===----------------------------------------------------------------------===//
1329 
1330 /// Parser for AtomicReadOp
1331 ///
1332 /// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type
1333 /// address ::= operand `:` type
1334 static ParseResult parseAtomicReadOp(OpAsmParser &parser,
1335                                      OperationState &result) {
1336   OpAsmParser::OperandType address;
1337   Type addressType;
1338   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1339   SmallVector<int> segments;
1340 
1341   if (parser.parseOperand(address) ||
1342       parseClauses(parser, result, clauses, segments) ||
1343       parser.parseColonType(addressType) ||
1344       parser.resolveOperand(address, addressType, result.operands))
1345     return failure();
1346 
1347   SmallVector<Type> resultType;
1348   if (parser.parseArrowTypeList(resultType))
1349     return failure();
1350   result.addTypes(resultType);
1351   return success();
1352 }
1353 
1354 /// Printer for AtomicReadOp
1355 static void printAtomicReadOp(OpAsmPrinter &p, AtomicReadOp op) {
1356   p << " " << op.address() << " ";
1357   if (op.memory_order())
1358     p << "memory_order(" << op.memory_order().getValue() << ") ";
1359   if (op.hintAttr())
1360     printSynchronizationHint(p << " ", op, op.hintAttr());
1361   p << ": " << op.address().getType() << " -> " << op.getType();
1362   return;
1363 }
1364 
1365 /// Verifier for AtomicReadOp
1366 static LogicalResult verifyAtomicReadOp(AtomicReadOp op) {
1367   if (op.memory_order()) {
1368     StringRef memOrder = op.memory_order().getValue();
1369     if (memOrder.equals("acq_rel") || memOrder.equals("release"))
1370       return op.emitError(
1371           "memory-order must not be acq_rel or release for atomic reads");
1372   }
1373   return verifySynchronizationHint(op, op.hint());
1374 }
1375 
1376 //===----------------------------------------------------------------------===//
1377 // AtomicWriteOp
1378 //===----------------------------------------------------------------------===//
1379 
1380 /// Parser for AtomicWriteOp
1381 ///
1382 /// operation ::= `omp.atomic.write` atomic-clause-list operands
1383 /// operands ::= address `,` value
1384 /// address ::= operand `:` type
1385 /// value ::= operand `:` type
1386 static ParseResult parseAtomicWriteOp(OpAsmParser &parser,
1387                                       OperationState &result) {
1388   OpAsmParser::OperandType address, value;
1389   Type addrType, valueType;
1390   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1391   SmallVector<int> segments;
1392 
1393   if (parser.parseOperand(address) || parser.parseComma() ||
1394       parser.parseOperand(value) ||
1395       parseClauses(parser, result, clauses, segments) ||
1396       parser.parseColonType(addrType) || parser.parseComma() ||
1397       parser.parseType(valueType) ||
1398       parser.resolveOperand(address, addrType, result.operands) ||
1399       parser.resolveOperand(value, valueType, result.operands))
1400     return failure();
1401   return success();
1402 }
1403 
1404 /// Printer for AtomicWriteOp
1405 static void printAtomicWriteOp(OpAsmPrinter &p, AtomicWriteOp op) {
1406   p << " " << op.address() << ", " << op.value() << " ";
1407   if (op.memory_order())
1408     p << "memory_order(" << op.memory_order() << ") ";
1409   if (op.hintAttr())
1410     printSynchronizationHint(p, op, op.hintAttr());
1411   p << ": " << op.address().getType() << ", " << op.value().getType();
1412   return;
1413 }
1414 
1415 /// Verifier for AtomicWriteOp
1416 static LogicalResult verifyAtomicWriteOp(AtomicWriteOp op) {
1417   if (op.memory_order()) {
1418     StringRef memoryOrder = op.memory_order().getValue();
1419     if (memoryOrder.equals("acq_rel") || memoryOrder.equals("acquire"))
1420       return op.emitError(
1421           "memory-order must not be acq_rel or acquire for atomic writes");
1422   }
1423   return verifySynchronizationHint(op, op.hint());
1424 }
1425 
1426 #define GET_OP_CLASSES
1427 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
1428