1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/OpImplementation.h"
19 #include "mlir/IR/OperationSupport.h"
20 
21 #include "llvm/ADT/BitVector.h"
22 #include "llvm/ADT/SmallString.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/ADT/StringSwitch.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include <cstddef>
28 
29 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
30 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
31 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
32 
33 using namespace mlir;
34 using namespace mlir::omp;
35 
36 namespace {
37 /// Model for pointer-like types that already provide a `getElementType` method.
38 template <typename T>
39 struct PointerLikeModel
40     : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
41   Type getElementType(Type pointer) const {
42     return pointer.cast<T>().getElementType();
43   }
44 };
45 } // namespace
46 
/// Registers the OpenMP dialect's operations and attributes, and attaches the
/// `PointerLikeType` interface to the LLVM pointer type and the builtin memref
/// type.
void OpenMPDialect::initialize() {
  // The operation and attribute lists come from TableGen-generated files; the
  // GET_*_LIST macros select the list section of each generated file.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
      >();

  // Both types already provide `getElementType()`, so the generic
  // PointerLikeModel adapter is enough to implement the interface for them.
  LLVM::LLVMPointerType::attachInterface<
      PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
  MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
}
61 
62 //===----------------------------------------------------------------------===//
63 // ParallelOp
64 //===----------------------------------------------------------------------===//
65 
66 void ParallelOp::build(OpBuilder &builder, OperationState &state,
67                        ArrayRef<NamedAttribute> attributes) {
68   ParallelOp::build(
69       builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
70       /*default_val=*/nullptr, /*private_vars=*/ValueRange(),
71       /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(),
72       /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(),
73       /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr);
74   state.addAttributes(attributes);
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // Parser and printer for Operand and type list
79 //===----------------------------------------------------------------------===//
80 
81 /// Parse a list of operands with types.
82 ///
83 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
84 /// ssa-id-and-type-list ::= ssa-id-and-type |
85 ///                          ssa-id-and-type `,` ssa-id-and-type-list
86 /// ssa-id-and-type ::= ssa-id `:` type
87 static ParseResult
88 parseOperandAndTypeList(OpAsmParser &parser,
89                         SmallVectorImpl<OpAsmParser::OperandType> &operands,
90                         SmallVectorImpl<Type> &types) {
91   return parser.parseCommaSeparatedList(
92       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
93         OpAsmParser::OperandType operand;
94         Type type;
95         if (parser.parseOperand(operand) || parser.parseColonType(type))
96           return failure();
97         operands.push_back(operand);
98         types.push_back(type);
99         return success();
100       });
101 }
102 
103 /// Print an operand and type list with parentheses
104 static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) {
105   p << "(";
106   llvm::interleaveComma(
107       operands, p, [&](const Value &v) { p << v << " : " << v.getType(); });
108   p << ") ";
109 }
110 
111 /// Print data variables corresponding to a data-sharing clause `name`
112 static void printDataVars(OpAsmPrinter &p, OperandRange operands,
113                           StringRef name) {
114   if (!operands.empty()) {
115     p << name;
116     printOperandAndTypeList(p, operands);
117   }
118 }
119 
120 //===----------------------------------------------------------------------===//
121 // Parser and printer for Allocate Clause
122 //===----------------------------------------------------------------------===//
123 
124 /// Parse an allocate clause with allocators and a list of operands with types.
125 ///
126 /// allocate ::= `allocate` `(` allocate-operand-list `)`
127 /// allocate-operand-list :: = allocate-operand |
128 ///                            allocator-operand `,` allocate-operand-list
129 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
130 /// ssa-id-and-type ::= ssa-id `:` type
131 static ParseResult parseAllocateAndAllocator(
132     OpAsmParser &parser,
133     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
134     SmallVectorImpl<Type> &typesAllocate,
135     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
136     SmallVectorImpl<Type> &typesAllocator) {
137 
138   return parser.parseCommaSeparatedList(
139       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
140         OpAsmParser::OperandType operand;
141         Type type;
142         if (parser.parseOperand(operand) || parser.parseColonType(type))
143           return failure();
144         operandsAllocator.push_back(operand);
145         typesAllocator.push_back(type);
146         if (parser.parseArrow())
147           return failure();
148         if (parser.parseOperand(operand) || parser.parseColonType(type))
149           return failure();
150 
151         operandsAllocate.push_back(operand);
152         typesAllocate.push_back(type);
153         return success();
154       });
155 }
156 
157 /// Print allocate clause
158 static void printAllocateAndAllocator(OpAsmPrinter &p,
159                                       OperandRange varsAllocate,
160                                       OperandRange varsAllocator) {
161   p << "allocate(";
162   for (unsigned i = 0; i < varsAllocate.size(); ++i) {
163     std::string separator = i == varsAllocate.size() - 1 ? ") " : ", ";
164     p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
165     p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
166   }
167 }
168 
169 static LogicalResult verifyParallelOp(ParallelOp op) {
170   if (op.allocate_vars().size() != op.allocators_vars().size())
171     return op.emitError(
172         "expected equal sizes for allocate and allocator variables");
173   return success();
174 }
175 
176 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
177   p << " ";
178   if (auto ifCond = op.if_expr_var())
179     p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
180 
181   if (auto threads = op.num_threads_var())
182     p << "num_threads(" << threads << " : " << threads.getType() << ") ";
183 
184   printDataVars(p, op.private_vars(), "private");
185   printDataVars(p, op.firstprivate_vars(), "firstprivate");
186   printDataVars(p, op.shared_vars(), "shared");
187   printDataVars(p, op.copyin_vars(), "copyin");
188 
189   if (!op.allocate_vars().empty())
190     printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars());
191 
192   if (auto def = op.default_val())
193     p << "default(" << stringifyClauseDefault(*def).drop_front(3) << ") ";
194 
195   if (auto bind = op.proc_bind_val())
196     p << "proc_bind(" << stringifyClauseProcBindKind(*bind) << ") ";
197 
198   p << ' ';
199   p.printRegion(op.getRegion());
200 }
201 
202 //===----------------------------------------------------------------------===//
203 // Parser and printer for Linear Clause
204 //===----------------------------------------------------------------------===//
205 
206 /// linear ::= `linear` `(` linear-list `)`
207 /// linear-list := linear-val | linear-val linear-list
208 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
209 static ParseResult
210 parseLinearClause(OpAsmParser &parser,
211                   SmallVectorImpl<OpAsmParser::OperandType> &vars,
212                   SmallVectorImpl<Type> &types,
213                   SmallVectorImpl<OpAsmParser::OperandType> &stepVars) {
214   if (parser.parseLParen())
215     return failure();
216 
217   do {
218     OpAsmParser::OperandType var;
219     Type type;
220     OpAsmParser::OperandType stepVar;
221     if (parser.parseOperand(var) || parser.parseEqual() ||
222         parser.parseOperand(stepVar) || parser.parseColonType(type))
223       return failure();
224 
225     vars.push_back(var);
226     types.push_back(type);
227     stepVars.push_back(stepVar);
228   } while (succeeded(parser.parseOptionalComma()));
229 
230   if (parser.parseRParen())
231     return failure();
232 
233   return success();
234 }
235 
236 /// Print Linear Clause
237 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars,
238                               OperandRange linearStepVars) {
239   size_t linearVarsSize = linearVars.size();
240   p << "linear(";
241   for (unsigned i = 0; i < linearVarsSize; ++i) {
242     std::string separator = i == linearVarsSize - 1 ? ") " : ", ";
243     p << linearVars[i];
244     if (linearStepVars.size() > i)
245       p << " = " << linearStepVars[i];
246     p << " : " << linearVars[i].getType() << separator;
247   }
248 }
249 
250 //===----------------------------------------------------------------------===//
251 // Parser and printer for Schedule Clause
252 //===----------------------------------------------------------------------===//
253 
254 static ParseResult
255 verifyScheduleModifiers(OpAsmParser &parser,
256                         SmallVectorImpl<SmallString<12>> &modifiers) {
257   if (modifiers.size() > 2)
258     return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
259   for (const auto &mod : modifiers) {
260     // Translate the string. If it has no value, then it was not a valid
261     // modifier!
262     auto symbol = symbolizeScheduleModifier(mod);
263     if (!symbol.hasValue())
264       return parser.emitError(parser.getNameLoc())
265              << " unknown modifier type: " << mod;
266   }
267 
268   // If we have one modifier that is "simd", then stick a "none" modiifer in
269   // index 0.
270   if (modifiers.size() == 1) {
271     if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
272       modifiers.push_back(modifiers[0]);
273       modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
274     }
275   } else if (modifiers.size() == 2) {
276     // If there are two modifier:
277     // First modifier should not be simd, second one should be simd
278     if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
279         symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
280       return parser.emitError(parser.getNameLoc())
281              << " incorrect modifier order";
282   }
283   return success();
284 }
285 
286 /// schedule ::= `schedule` `(` sched-list `)`
287 /// sched-list ::= sched-val | sched-val sched-list |
288 ///                sched-val `,` sched-modifier
289 /// sched-val ::= sched-with-chunk | sched-wo-chunk
290 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
291 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
292 /// sched-wo-chunk ::=  `auto` | `runtime`
293 /// sched-modifier ::=  sched-mod-val | sched-mod-val `,` sched-mod-val
294 /// sched-mod-val ::=  `monotonic` | `nonmonotonic` | `simd` | `none`
295 static ParseResult
296 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
297                     SmallVectorImpl<SmallString<12>> &modifiers,
298                     Optional<OpAsmParser::OperandType> &chunkSize) {
299   if (parser.parseLParen())
300     return failure();
301 
302   StringRef keyword;
303   if (parser.parseKeyword(&keyword))
304     return failure();
305 
306   schedule = keyword;
307   if (keyword == "static" || keyword == "dynamic" || keyword == "guided") {
308     if (succeeded(parser.parseOptionalEqual())) {
309       chunkSize = OpAsmParser::OperandType{};
310       if (parser.parseOperand(*chunkSize))
311         return failure();
312     } else {
313       chunkSize = llvm::NoneType::None;
314     }
315   } else if (keyword == "auto" || keyword == "runtime") {
316     chunkSize = llvm::NoneType::None;
317   } else {
318     return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
319   }
320 
321   // If there is a comma, we have one or more modifiers..
322   while (succeeded(parser.parseOptionalComma())) {
323     StringRef mod;
324     if (parser.parseKeyword(&mod))
325       return failure();
326     modifiers.push_back(mod);
327   }
328 
329   if (parser.parseRParen())
330     return failure();
331 
332   if (verifyScheduleModifiers(parser, modifiers))
333     return failure();
334 
335   return success();
336 }
337 
338 /// Print schedule clause
339 static void printScheduleClause(OpAsmPrinter &p, ClauseScheduleKind sched,
340                                 Optional<ScheduleModifier> modifier, bool simd,
341                                 Value scheduleChunkVar) {
342   p << "schedule(" << stringifyClauseScheduleKind(sched).lower();
343   if (scheduleChunkVar)
344     p << " = " << scheduleChunkVar;
345   if (modifier)
346     p << ", " << stringifyScheduleModifier(*modifier);
347   if (simd)
348     p << ", simd";
349   p << ") ";
350 }
351 
352 //===----------------------------------------------------------------------===//
353 // Parser, printer and verifier for ReductionVarList
354 //===----------------------------------------------------------------------===//
355 
356 /// reduction ::= `reduction` `(` reduction-entry-list `)`
357 /// reduction-entry-list ::= reduction-entry
358 ///                        | reduction-entry-list `,` reduction-entry
359 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
360 static ParseResult
361 parseReductionVarList(OpAsmParser &parser,
362                       SmallVectorImpl<SymbolRefAttr> &symbols,
363                       SmallVectorImpl<OpAsmParser::OperandType> &operands,
364                       SmallVectorImpl<Type> &types) {
365   if (failed(parser.parseLParen()))
366     return failure();
367 
368   do {
369     if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() ||
370         parser.parseOperand(operands.emplace_back()) ||
371         parser.parseColonType(types.emplace_back()))
372       return failure();
373   } while (succeeded(parser.parseOptionalComma()));
374   return parser.parseRParen();
375 }
376 
377 /// Print Reduction clause
378 static void printReductionVarList(OpAsmPrinter &p,
379                                   Optional<ArrayAttr> reductions,
380                                   OperandRange reductionVars) {
381   p << "reduction(";
382   for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
383     if (i != 0)
384       p << ", ";
385     p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
386       << reductionVars[i].getType();
387   }
388   p << ") ";
389 }
390 
391 /// Verifies Reduction Clause
392 static LogicalResult verifyReductionVarList(Operation *op,
393                                             Optional<ArrayAttr> reductions,
394                                             OperandRange reductionVars) {
395   if (!reductionVars.empty()) {
396     if (!reductions || reductions->size() != reductionVars.size())
397       return op->emitOpError()
398              << "expected as many reduction symbol references "
399                 "as reduction variables";
400   } else {
401     if (reductions)
402       return op->emitOpError() << "unexpected reduction symbol references";
403     return success();
404   }
405 
406   DenseSet<Value> accumulators;
407   for (auto args : llvm::zip(reductionVars, *reductions)) {
408     Value accum = std::get<0>(args);
409 
410     if (!accumulators.insert(accum).second)
411       return op->emitOpError() << "accumulator variable used more than once";
412 
413     Type varType = accum.getType().cast<PointerLikeType>();
414     auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
415     auto decl =
416         SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
417     if (!decl)
418       return op->emitOpError() << "expected symbol reference " << symbolRef
419                                << " to point to a reduction declaration";
420 
421     if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
422       return op->emitOpError()
423              << "expected accumulator (" << varType
424              << ") to be the same type as reduction declaration ("
425              << decl.getAccumulatorType() << ")";
426   }
427 
428   return success();
429 }
430 
431 //===----------------------------------------------------------------------===//
432 // Parser, printer and verifier for Synchronization Hint (2.17.12)
433 //===----------------------------------------------------------------------===//
434 
435 /// Parses a Synchronization Hint clause. The value of hint is an integer
436 /// which is a combination of different hints from `omp_sync_hint_t`.
437 ///
438 /// hint-clause = `hint` `(` hint-value `)`
439 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
440                                             IntegerAttr &hintAttr,
441                                             bool parseKeyword = true) {
442   if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) {
443     hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
444     return success();
445   }
446 
447   if (failed(parser.parseLParen()))
448     return failure();
449   StringRef hintKeyword;
450   int64_t hint = 0;
451   do {
452     if (failed(parser.parseKeyword(&hintKeyword)))
453       return failure();
454     if (hintKeyword == "uncontended")
455       hint |= 1;
456     else if (hintKeyword == "contended")
457       hint |= 2;
458     else if (hintKeyword == "nonspeculative")
459       hint |= 4;
460     else if (hintKeyword == "speculative")
461       hint |= 8;
462     else
463       return parser.emitError(parser.getCurrentLocation())
464              << hintKeyword << " is not a valid hint";
465   } while (succeeded(parser.parseOptionalComma()));
466   if (failed(parser.parseRParen()))
467     return failure();
468   hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
469   return success();
470 }
471 
472 /// Prints a Synchronization Hint clause
473 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
474                                      IntegerAttr hintAttr) {
475   int64_t hint = hintAttr.getInt();
476 
477   if (hint == 0)
478     return;
479 
480   // Helper function to get n-th bit from the right end of `value`
481   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
482 
483   bool uncontended = bitn(hint, 0);
484   bool contended = bitn(hint, 1);
485   bool nonspeculative = bitn(hint, 2);
486   bool speculative = bitn(hint, 3);
487 
488   SmallVector<StringRef> hints;
489   if (uncontended)
490     hints.push_back("uncontended");
491   if (contended)
492     hints.push_back("contended");
493   if (nonspeculative)
494     hints.push_back("nonspeculative");
495   if (speculative)
496     hints.push_back("speculative");
497 
498   p << "hint(";
499   llvm::interleaveComma(hints, p);
500   p << ") ";
501 }
502 
503 /// Verifies a synchronization hint clause
504 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
505 
506   // Helper function to get n-th bit from the right end of `value`
507   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
508 
509   bool uncontended = bitn(hint, 0);
510   bool contended = bitn(hint, 1);
511   bool nonspeculative = bitn(hint, 2);
512   bool speculative = bitn(hint, 3);
513 
514   if (uncontended && contended)
515     return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
516                                 "omp_sync_hint_contended cannot be combined";
517   if (nonspeculative && speculative)
518     return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
519                                 "omp_sync_hint_speculative cannot be combined.";
520   return success();
521 }
522 
/// Identifies each clause kind that the clause-list parser below can handle.
///
/// The enumerators are used directly as indices into per-clause tables sized
/// by `COUNT`, so `COUNT` must remain the last enumerator. The enum is
/// unscoped because the parser refers to the enumerators unqualified and
/// indexes containers with them.
enum ClauseType {
  ifClause,          // `if`
  numThreadsClause,  // `num_threads`
  privateClause,     // `private`
  firstprivateClause, // `firstprivate`
  lastprivateClause, // `lastprivate`
  sharedClause,      // `shared`
  copyinClause,      // `copyin`
  allocateClause,    // `allocate`
  defaultClause,     // `default`
  procBindClause,    // `proc_bind`
  reductionClause,   // `reduction`
  nowaitClause,      // `nowait`
  linearClause,      // `linear`
  scheduleClause,    // `schedule`
  collapseClause,    // `collapse`
  orderClause,       // `order`
  orderedClause,     // `ordered`
  memoryOrderClause, // `memory_order`
  hintClause,        // `hint`
  COUNT              // Number of clause kinds; not a clause itself.
};
545 
546 //===----------------------------------------------------------------------===//
547 // Parser for Clause List
548 //===----------------------------------------------------------------------===//
549 
550 /// Parse a clause attribute `(` $value `)`.
551 template <typename ClauseAttr>
552 static ParseResult parseClauseAttr(AsmParser &parser, OperationState &state,
553                                    StringRef attrName, StringRef name) {
554   using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
555   StringRef enumStr;
556   llvm::SMLoc loc = parser.getCurrentLocation();
557   if (parser.parseLParen() || parser.parseKeyword(&enumStr) ||
558       parser.parseRParen())
559     return failure();
560   if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
561     auto attr = ClauseAttr::get(parser.getContext(), *enumValue);
562     state.addAttribute(attrName, attr);
563     return success();
564   }
565   return parser.emitError(loc, "invalid ") << name << " kind";
566 }
567 
/// Parse a list of clauses. The clauses can appear in any order, but their
/// operand segment indices are in the same order that they are passed in the
/// `clauses` list. The operand segments of these clauses are added on top of
/// the segments already present in `segments`.
///
/// clause-list ::= clause clause-list | empty
/// clause ::= if | num-threads | private | firstprivate | lastprivate |
///            shared | copyin | allocate | default | proc-bind | reduction |
///            nowait | linear | schedule | collapse | order | ordered |
///            memory-order | hint
577 /// if ::= `if` `(` ssa-id-and-type `)`
578 /// num-threads ::= `num_threads` `(` ssa-id-and-type `)`
579 /// private ::= `private` operand-and-type-list
580 /// firstprivate ::= `firstprivate` operand-and-type-list
581 /// lastprivate ::= `lastprivate` operand-and-type-list
582 /// shared ::= `shared` operand-and-type-list
583 /// copyin ::= `copyin` operand-and-type-list
584 /// allocate ::= `allocate` `(` allocate-operand-list `)`
/// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) `)`
586 /// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
587 /// reduction ::= `reduction` `(` reduction-entry-list `)`
588 /// nowait ::= `nowait`
589 /// linear ::= `linear` `(` linear-list `)`
590 /// schedule ::= `schedule` `(` sched-list `)`
/// collapse ::= `collapse` `(` integer `)`
/// order ::= `order` `(` `concurrent` `)`
/// ordered ::= `ordered` (`(` integer `)`)?
/// memory-order ::= `memory_order` `(` memory-order-kind `)`
/// hint ::= `hint` `(` hint-value-list `)`
595 ///
/// Note that each clause can only appear once in the clause-list.
597 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
598                                 SmallVectorImpl<ClauseType> &clauses,
599                                 SmallVectorImpl<int> &segments) {
600 
601   // Check done[clause] to see if it has been parsed already
602   llvm::BitVector done(ClauseType::COUNT, false);
603 
604   // See pos[clause] to get position of clause in operand segments
605   SmallVector<int> pos(ClauseType::COUNT, -1);
606 
607   // Stores the last parsed clause keyword
608   StringRef clauseKeyword;
609   StringRef opName = result.name.getStringRef();
610 
611   // Containers for storing operands, types and attributes for various clauses
612   std::pair<OpAsmParser::OperandType, Type> ifCond;
613   std::pair<OpAsmParser::OperandType, Type> numThreads;
614 
615   SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates,
616       shareds, copyins;
617   SmallVector<Type> privateTypes, firstprivateTypes, lastprivateTypes,
618       sharedTypes, copyinTypes;
619 
620   SmallVector<OpAsmParser::OperandType> allocates, allocators;
621   SmallVector<Type> allocateTypes, allocatorTypes;
622 
623   SmallVector<SymbolRefAttr> reductionSymbols;
624   SmallVector<OpAsmParser::OperandType> reductionVars;
625   SmallVector<Type> reductionVarTypes;
626 
627   SmallVector<OpAsmParser::OperandType> linears;
628   SmallVector<Type> linearTypes;
629   SmallVector<OpAsmParser::OperandType> linearSteps;
630 
631   SmallString<8> schedule;
632   SmallVector<SmallString<12>> modifiers;
633   Optional<OpAsmParser::OperandType> scheduleChunkSize;
634 
635   // Compute the position of clauses in operand segments
636   int currPos = 0;
637   for (ClauseType clause : clauses) {
638 
639     // Skip the following clauses - they do not take any position in operand
640     // segments
641     if (clause == defaultClause || clause == procBindClause ||
642         clause == nowaitClause || clause == collapseClause ||
643         clause == orderClause || clause == orderedClause)
644       continue;
645 
646     pos[clause] = currPos++;
647 
648     // For the following clauses, two positions are reserved in the operand
649     // segments
650     if (clause == allocateClause || clause == linearClause)
651       currPos++;
652   }
653 
654   SmallVector<int> clauseSegments(currPos);
655 
656   // Helper function to check if a clause is allowed/repeated or not
657   auto checkAllowed = [&](ClauseType clause) -> ParseResult {
658     if (!llvm::is_contained(clauses, clause))
659       return parser.emitError(parser.getCurrentLocation())
660              << clauseKeyword << " is not a valid clause for the " << opName
661              << " operation";
662     if (done[clause])
663       return parser.emitError(parser.getCurrentLocation())
664              << "at most one " << clauseKeyword << " clause can appear on the "
665              << opName << " operation";
666     done[clause] = true;
667     return success();
668   };
669 
670   while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) {
671     if (clauseKeyword == "if") {
672       if (checkAllowed(ifClause) || parser.parseLParen() ||
673           parser.parseOperand(ifCond.first) ||
674           parser.parseColonType(ifCond.second) || parser.parseRParen())
675         return failure();
676       clauseSegments[pos[ifClause]] = 1;
677     } else if (clauseKeyword == "num_threads") {
678       if (checkAllowed(numThreadsClause) || parser.parseLParen() ||
679           parser.parseOperand(numThreads.first) ||
680           parser.parseColonType(numThreads.second) || parser.parseRParen())
681         return failure();
682       clauseSegments[pos[numThreadsClause]] = 1;
683     } else if (clauseKeyword == "private") {
684       if (checkAllowed(privateClause) ||
685           parseOperandAndTypeList(parser, privates, privateTypes))
686         return failure();
687       clauseSegments[pos[privateClause]] = privates.size();
688     } else if (clauseKeyword == "firstprivate") {
689       if (checkAllowed(firstprivateClause) ||
690           parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
691         return failure();
692       clauseSegments[pos[firstprivateClause]] = firstprivates.size();
693     } else if (clauseKeyword == "lastprivate") {
694       if (checkAllowed(lastprivateClause) ||
695           parseOperandAndTypeList(parser, lastprivates, lastprivateTypes))
696         return failure();
697       clauseSegments[pos[lastprivateClause]] = lastprivates.size();
698     } else if (clauseKeyword == "shared") {
699       if (checkAllowed(sharedClause) ||
700           parseOperandAndTypeList(parser, shareds, sharedTypes))
701         return failure();
702       clauseSegments[pos[sharedClause]] = shareds.size();
703     } else if (clauseKeyword == "copyin") {
704       if (checkAllowed(copyinClause) ||
705           parseOperandAndTypeList(parser, copyins, copyinTypes))
706         return failure();
707       clauseSegments[pos[copyinClause]] = copyins.size();
708     } else if (clauseKeyword == "allocate") {
709       if (checkAllowed(allocateClause) ||
710           parseAllocateAndAllocator(parser, allocates, allocateTypes,
711                                     allocators, allocatorTypes))
712         return failure();
713       clauseSegments[pos[allocateClause]] = allocates.size();
714       clauseSegments[pos[allocateClause] + 1] = allocators.size();
715     } else if (clauseKeyword == "default") {
716       StringRef defval;
717       llvm::SMLoc loc = parser.getCurrentLocation();
718       if (checkAllowed(defaultClause) || parser.parseLParen() ||
719           parser.parseKeyword(&defval) || parser.parseRParen())
720         return failure();
721       // The def prefix is required for the attribute as "private" is a keyword
722       // in C++.
723       if (Optional<ClauseDefault> def =
724               symbolizeClauseDefault(("def" + defval).str())) {
725         result.addAttribute("default_val",
726                             ClauseDefaultAttr::get(parser.getContext(), *def));
727       } else {
728         return parser.emitError(loc, "invalid default clause");
729       }
730     } else if (clauseKeyword == "proc_bind") {
731       if (checkAllowed(procBindClause) ||
732           parseClauseAttr<ClauseProcBindKindAttr>(parser, result,
733                                                   "proc_bind_val", "proc bind"))
734         return failure();
735     } else if (clauseKeyword == "reduction") {
736       if (checkAllowed(reductionClause) ||
737           parseReductionVarList(parser, reductionSymbols, reductionVars,
738                                 reductionVarTypes))
739         return failure();
740       clauseSegments[pos[reductionClause]] = reductionVars.size();
741     } else if (clauseKeyword == "nowait") {
742       if (checkAllowed(nowaitClause))
743         return failure();
744       auto attr = UnitAttr::get(parser.getBuilder().getContext());
745       result.addAttribute("nowait", attr);
746     } else if (clauseKeyword == "linear") {
747       if (checkAllowed(linearClause) ||
748           parseLinearClause(parser, linears, linearTypes, linearSteps))
749         return failure();
750       clauseSegments[pos[linearClause]] = linears.size();
751       clauseSegments[pos[linearClause] + 1] = linearSteps.size();
752     } else if (clauseKeyword == "schedule") {
753       if (checkAllowed(scheduleClause) ||
754           parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize))
755         return failure();
756       if (scheduleChunkSize) {
757         clauseSegments[pos[scheduleClause]] = 1;
758       }
759     } else if (clauseKeyword == "collapse") {
760       auto type = parser.getBuilder().getI64Type();
761       mlir::IntegerAttr attr;
762       if (checkAllowed(collapseClause) || parser.parseLParen() ||
763           parser.parseAttribute(attr, type) || parser.parseRParen())
764         return failure();
765       result.addAttribute("collapse_val", attr);
766     } else if (clauseKeyword == "ordered") {
767       mlir::IntegerAttr attr;
768       if (checkAllowed(orderedClause))
769         return failure();
770       if (succeeded(parser.parseOptionalLParen())) {
771         auto type = parser.getBuilder().getI64Type();
772         if (parser.parseAttribute(attr, type) || parser.parseRParen())
773           return failure();
774       } else {
775         // Use 0 to represent no ordered parameter was specified
776         attr = parser.getBuilder().getI64IntegerAttr(0);
777       }
778       result.addAttribute("ordered_val", attr);
779     } else if (clauseKeyword == "order") {
780       if (checkAllowed(orderClause) ||
781           parseClauseAttr<ClauseOrderKindAttr>(parser, result, "order_val",
782                                                "order"))
783         return failure();
784     } else if (clauseKeyword == "memory_order") {
785       if (checkAllowed(memoryOrderClause) ||
786           parseClauseAttr<ClauseMemoryOrderKindAttr>(
787               parser, result, "memory_order", "memory order"))
788         return failure();
789     } else if (clauseKeyword == "hint") {
790       IntegerAttr hint;
791       if (checkAllowed(hintClause) ||
792           parseSynchronizationHint(parser, hint, false))
793         return failure();
794       result.addAttribute("hint", hint);
795     } else {
796       return parser.emitError(parser.getNameLoc())
797              << clauseKeyword << " is not a valid clause";
798     }
799   }
800 
801   // Add if parameter.
802   if (done[ifClause] && clauseSegments[pos[ifClause]] &&
803       failed(
804           parser.resolveOperand(ifCond.first, ifCond.second, result.operands)))
805     return failure();
806 
807   // Add num_threads parameter.
808   if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] &&
809       failed(parser.resolveOperand(numThreads.first, numThreads.second,
810                                    result.operands)))
811     return failure();
812 
813   // Add private parameters.
814   if (done[privateClause] && clauseSegments[pos[privateClause]] &&
815       failed(parser.resolveOperands(privates, privateTypes,
816                                     privates[0].location, result.operands)))
817     return failure();
818 
819   // Add firstprivate parameters.
820   if (done[firstprivateClause] && clauseSegments[pos[firstprivateClause]] &&
821       failed(parser.resolveOperands(firstprivates, firstprivateTypes,
822                                     firstprivates[0].location,
823                                     result.operands)))
824     return failure();
825 
826   // Add lastprivate parameters.
827   if (done[lastprivateClause] && clauseSegments[pos[lastprivateClause]] &&
828       failed(parser.resolveOperands(lastprivates, lastprivateTypes,
829                                     lastprivates[0].location, result.operands)))
830     return failure();
831 
832   // Add shared parameters.
833   if (done[sharedClause] && clauseSegments[pos[sharedClause]] &&
834       failed(parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
835                                     result.operands)))
836     return failure();
837 
838   // Add copyin parameters.
839   if (done[copyinClause] && clauseSegments[pos[copyinClause]] &&
840       failed(parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
841                                     result.operands)))
842     return failure();
843 
844   // Add allocate parameters.
845   if (done[allocateClause] && clauseSegments[pos[allocateClause]] &&
846       failed(parser.resolveOperands(allocates, allocateTypes,
847                                     allocates[0].location, result.operands)))
848     return failure();
849 
850   // Add allocator parameters.
851   if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] &&
852       failed(parser.resolveOperands(allocators, allocatorTypes,
853                                     allocators[0].location, result.operands)))
854     return failure();
855 
856   // Add reduction parameters and symbols
857   if (done[reductionClause] && clauseSegments[pos[reductionClause]]) {
858     if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
859                                       parser.getNameLoc(), result.operands)))
860       return failure();
861 
862     SmallVector<Attribute> reductions(reductionSymbols.begin(),
863                                       reductionSymbols.end());
864     result.addAttribute("reductions",
865                         parser.getBuilder().getArrayAttr(reductions));
866   }
867 
868   // Add linear parameters
869   if (done[linearClause] && clauseSegments[pos[linearClause]]) {
870     auto linearStepType = parser.getBuilder().getI32Type();
871     SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
872     if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location,
873                                       result.operands)) ||
874         failed(parser.resolveOperands(linearSteps, linearStepTypes,
875                                       linearSteps[0].location,
876                                       result.operands)))
877       return failure();
878   }
879 
880   // Add schedule parameters
881   if (done[scheduleClause] && !schedule.empty()) {
882     schedule[0] = llvm::toUpper(schedule[0]);
883     if (Optional<ClauseScheduleKind> sched =
884             symbolizeClauseScheduleKind(schedule)) {
885       auto attr = ClauseScheduleKindAttr::get(parser.getContext(), *sched);
886       result.addAttribute("schedule_val", attr);
887     } else {
888       return parser.emitError(parser.getCurrentLocation(),
889                               "invalid schedule kind");
890     }
891     if (!modifiers.empty()) {
892       llvm::SMLoc loc = parser.getCurrentLocation();
893       if (Optional<ScheduleModifier> mod =
894               symbolizeScheduleModifier(modifiers[0])) {
895         result.addAttribute(
896             "schedule_modifier",
897             ScheduleModifierAttr::get(parser.getContext(), *mod));
898       } else {
899         return parser.emitError(loc, "invalid schedule modifier");
900       }
901       // Only SIMD attribute is allowed here!
902       if (modifiers.size() > 1) {
903         assert(symbolizeScheduleModifier(modifiers[1]) ==
904                ScheduleModifier::simd);
905         auto attr = UnitAttr::get(parser.getBuilder().getContext());
906         result.addAttribute("simd_modifier", attr);
907       }
908     }
909     if (scheduleChunkSize) {
910       auto chunkSizeType = parser.getBuilder().getI32Type();
911       parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands);
912     }
913   }
914 
915   segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end());
916 
917   return success();
918 }
919 
920 /// Parses a parallel operation.
921 ///
922 /// operation ::= `omp.parallel` clause-list
923 /// clause-list ::= clause | clause clause-list
924 /// clause ::= if | num-threads | private | firstprivate | shared | copyin |
925 ///            allocate | default | proc-bind
926 ///
927 static ParseResult parseParallelOp(OpAsmParser &parser,
928                                    OperationState &result) {
929   SmallVector<ClauseType> clauses = {
930       ifClause,           numThreadsClause, privateClause,
931       firstprivateClause, sharedClause,     copyinClause,
932       allocateClause,     defaultClause,    procBindClause};
933 
934   SmallVector<int> segments;
935 
936   if (failed(parseClauses(parser, result, clauses, segments)))
937     return failure();
938 
939   result.addAttribute("operand_segment_sizes",
940                       parser.getBuilder().getI32VectorAttr(segments));
941 
942   Region *body = result.addRegion();
943   SmallVector<OpAsmParser::OperandType> regionArgs;
944   SmallVector<Type> regionArgTypes;
945   if (parser.parseRegion(*body, regionArgs, regionArgTypes))
946     return failure();
947   return success();
948 }
949 
950 //===----------------------------------------------------------------------===//
951 // Parser, printer and verifier for SectionsOp
952 //===----------------------------------------------------------------------===//
953 
954 /// Parses an OpenMP Sections operation
955 ///
956 /// sections ::= `omp.sections` clause-list
957 /// clause-list ::= clause clause-list | empty
958 /// clause ::= private | firstprivate | lastprivate | reduction | allocate |
959 ///            nowait
960 static ParseResult parseSectionsOp(OpAsmParser &parser,
961                                    OperationState &result) {
962 
963   SmallVector<ClauseType> clauses = {privateClause,     firstprivateClause,
964                                      lastprivateClause, reductionClause,
965                                      allocateClause,    nowaitClause};
966 
967   SmallVector<int> segments;
968 
969   if (failed(parseClauses(parser, result, clauses, segments)))
970     return failure();
971 
972   result.addAttribute("operand_segment_sizes",
973                       parser.getBuilder().getI32VectorAttr(segments));
974 
975   // Now parse the body.
976   Region *body = result.addRegion();
977   if (parser.parseRegion(*body))
978     return failure();
979   return success();
980 }
981 
982 static void printSectionsOp(OpAsmPrinter &p, SectionsOp op) {
983   p << " ";
984   printDataVars(p, op.private_vars(), "private");
985   printDataVars(p, op.firstprivate_vars(), "firstprivate");
986   printDataVars(p, op.lastprivate_vars(), "lastprivate");
987 
988   if (!op.reduction_vars().empty())
989     printReductionVarList(p, op.reductions(), op.reduction_vars());
990 
991   if (!op.allocate_vars().empty())
992     printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars());
993 
994   if (op.nowait())
995     p << "nowait";
996 
997   p << ' ';
998   p.printRegion(op.region());
999 }
1000 
1001 static LogicalResult verifySectionsOp(SectionsOp op) {
1002 
1003   // A list item may not appear in more than one clause on the same directive,
1004   // except that it may be specified in both firstprivate and lastprivate
1005   // clauses.
1006   for (auto var : op.private_vars()) {
1007     if (llvm::is_contained(op.firstprivate_vars(), var))
1008       return op.emitOpError()
1009              << "operand used in both private and firstprivate clauses";
1010     if (llvm::is_contained(op.lastprivate_vars(), var))
1011       return op.emitOpError()
1012              << "operand used in both private and lastprivate clauses";
1013   }
1014 
1015   if (op.allocate_vars().size() != op.allocators_vars().size())
1016     return op.emitError(
1017         "expected equal sizes for allocate and allocator variables");
1018 
1019   for (auto &inst : *op.region().begin()) {
1020     if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst)))
1021       op.emitOpError()
1022           << "expected omp.section op or terminator op inside region";
1023   }
1024 
1025   return verifyReductionVarList(op, op.reductions(), op.reduction_vars());
1026 }
1027 
1028 /// Parses an OpenMP Workshare Loop operation
1029 ///
1030 /// wsloop ::= `omp.wsloop` loop-control clause-list
1031 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
1032 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
1033 /// steps := `step` `(`ssa-id-list`)`
1034 /// clause-list ::= clause clause-list | empty
1035 /// clause ::= private | firstprivate | lastprivate | linear | schedule |
1036 //             collapse | nowait | ordered | order | reduction
1037 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
1038 
1039   // Parse an opening `(` followed by induction variables followed by `)`
1040   SmallVector<OpAsmParser::OperandType> ivs;
1041   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
1042                                      OpAsmParser::Delimiter::Paren))
1043     return failure();
1044 
1045   int numIVs = static_cast<int>(ivs.size());
1046   Type loopVarType;
1047   if (parser.parseColonType(loopVarType))
1048     return failure();
1049 
1050   // Parse loop bounds.
1051   SmallVector<OpAsmParser::OperandType> lower;
1052   if (parser.parseEqual() ||
1053       parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
1054       parser.resolveOperands(lower, loopVarType, result.operands))
1055     return failure();
1056 
1057   SmallVector<OpAsmParser::OperandType> upper;
1058   if (parser.parseKeyword("to") ||
1059       parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
1060       parser.resolveOperands(upper, loopVarType, result.operands))
1061     return failure();
1062 
1063   if (succeeded(parser.parseOptionalKeyword("inclusive"))) {
1064     auto attr = UnitAttr::get(parser.getBuilder().getContext());
1065     result.addAttribute("inclusive", attr);
1066   }
1067 
1068   // Parse step values.
1069   SmallVector<OpAsmParser::OperandType> steps;
1070   if (parser.parseKeyword("step") ||
1071       parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
1072       parser.resolveOperands(steps, loopVarType, result.operands))
1073     return failure();
1074 
1075   SmallVector<ClauseType> clauses = {
1076       privateClause,   firstprivateClause, lastprivateClause, linearClause,
1077       reductionClause, collapseClause,     orderClause,       orderedClause,
1078       nowaitClause,    scheduleClause};
1079   SmallVector<int> segments{numIVs, numIVs, numIVs};
1080   if (failed(parseClauses(parser, result, clauses, segments)))
1081     return failure();
1082 
1083   result.addAttribute("operand_segment_sizes",
1084                       parser.getBuilder().getI32VectorAttr(segments));
1085 
1086   // Now parse the body.
1087   Region *body = result.addRegion();
1088   SmallVector<Type> ivTypes(numIVs, loopVarType);
1089   SmallVector<OpAsmParser::OperandType> blockArgs(ivs);
1090   if (parser.parseRegion(*body, blockArgs, ivTypes))
1091     return failure();
1092   return success();
1093 }
1094 
1095 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
1096   auto args = op.getRegion().front().getArguments();
1097   p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound()
1098     << ") to (" << op.upperBound() << ") ";
1099   if (op.inclusive()) {
1100     p << "inclusive ";
1101   }
1102   p << "step (" << op.step() << ") ";
1103 
1104   printDataVars(p, op.private_vars(), "private");
1105   printDataVars(p, op.firstprivate_vars(), "firstprivate");
1106   printDataVars(p, op.lastprivate_vars(), "lastprivate");
1107 
1108   if (!op.linear_vars().empty())
1109     printLinearClause(p, op.linear_vars(), op.linear_step_vars());
1110 
1111   if (auto sched = op.schedule_val())
1112     printScheduleClause(p, sched.getValue(), op.schedule_modifier(),
1113                         op.simd_modifier(), op.schedule_chunk_var());
1114 
1115   if (auto collapse = op.collapse_val())
1116     p << "collapse(" << collapse << ") ";
1117 
1118   if (op.nowait())
1119     p << "nowait ";
1120 
1121   if (auto ordered = op.ordered_val())
1122     p << "ordered(" << ordered << ") ";
1123 
1124   if (auto order = op.order_val())
1125     p << "order(" << stringifyClauseOrderKind(*order) << ") ";
1126 
1127   if (!op.reduction_vars().empty())
1128     printReductionVarList(p, op.reductions(), op.reduction_vars());
1129 
1130   p << ' ';
1131   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
1132 }
1133 
1134 //===----------------------------------------------------------------------===//
1135 // ReductionOp
1136 //===----------------------------------------------------------------------===//
1137 
1138 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
1139                                               Region &region) {
1140   if (parser.parseOptionalKeyword("atomic"))
1141     return success();
1142   return parser.parseRegion(region);
1143 }
1144 
1145 static void printAtomicReductionRegion(OpAsmPrinter &printer,
1146                                        ReductionDeclareOp op, Region &region) {
1147   if (region.empty())
1148     return;
1149   printer << "atomic ";
1150   printer.printRegion(region);
1151 }
1152 
/// Verifies an omp.reduction.declare operation.
///
/// Checks, in order (the first failing check produces the diagnostic):
///  - the initializer region is non-empty, its entry block takes exactly one
///    argument of the reduction type, and every YieldOp returned by
///    getOps<YieldOp>() on that region yields exactly one value of that type;
///  - the reduction region is non-empty, its entry block takes exactly two
///    arguments of the reduction type, and its YieldOps likewise yield a
///    single value of the reduction type;
///  - if the optional atomic reduction region is present, its entry block
///    takes two arguments of the same PointerLikeType whose element type is
///    the reduction type.
static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) {
  if (op.initializerRegion().empty())
    return op.emitOpError() << "expects non-empty initializer region";
  Block &initializerEntryBlock = op.initializerRegion().front();
  if (initializerEntryBlock.getNumArguments() != 1 ||
      initializerEntryBlock.getArgument(0).getType() != op.type()) {
    return op.emitOpError() << "expects initializer region with one argument "
                               "of the reduction type";
  }

  for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) {
    // The size check short-circuits before getTypes()[0] is accessed.
    if (yieldOp.results().size() != 1 ||
        yieldOp.results().getTypes()[0] != op.type())
      return op.emitOpError() << "expects initializer region to yield a value "
                                 "of the reduction type";
  }

  if (op.reductionRegion().empty())
    return op.emitOpError() << "expects non-empty reduction region";
  Block &reductionEntryBlock = op.reductionRegion().front();
  // Both arguments must have the same type, and that type must be the
  // reduction type.
  if (reductionEntryBlock.getNumArguments() != 2 ||
      reductionEntryBlock.getArgumentTypes()[0] !=
          reductionEntryBlock.getArgumentTypes()[1] ||
      reductionEntryBlock.getArgumentTypes()[0] != op.type())
    return op.emitOpError() << "expects reduction region with two arguments of "
                               "the reduction type";
  for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) {
    if (yieldOp.results().size() != 1 ||
        yieldOp.results().getTypes()[0] != op.type())
      return op.emitOpError() << "expects reduction region to yield a value "
                                 "of the reduction type";
  }

  // The atomic reduction region is optional.
  if (op.atomicReductionRegion().empty())
    return success();

  Block &atomicReductionEntryBlock = op.atomicReductionRegion().front();
  if (atomicReductionEntryBlock.getNumArguments() != 2 ||
      atomicReductionEntryBlock.getArgumentTypes()[0] !=
          atomicReductionEntryBlock.getArgumentTypes()[1])
    return op.emitOpError() << "expects atomic reduction region with two "
                               "arguments of the same type";
  // Arguments of the atomic region are pointer-like accumulators, so their
  // element type (not the argument type itself) must match the reduction type.
  auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
                     .dyn_cast<PointerLikeType>();
  if (!ptrType || ptrType.getElementType() != op.type())
    return op.emitOpError() << "expects atomic reduction region arguments to "
                               "be accumulators containing the reduction type";
  return success();
}
1202 
1203 static LogicalResult verifyReductionOp(ReductionOp op) {
1204   // TODO: generalize this to an op interface when there is more than one op
1205   // that supports reductions.
1206   auto container = op->getParentOfType<WsLoopOp>();
1207   for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
1208     if (container.reduction_vars()[i] == op.accumulator())
1209       return success();
1210 
1211   return op.emitOpError() << "the accumulator is not used by the parent";
1212 }
1213 
1214 //===----------------------------------------------------------------------===//
1215 // WsLoopOp
1216 //===----------------------------------------------------------------------===//
1217 
1218 void WsLoopOp::build(OpBuilder &builder, OperationState &state,
1219                      ValueRange lowerBound, ValueRange upperBound,
1220                      ValueRange step, ArrayRef<NamedAttribute> attributes) {
1221   build(builder, state, TypeRange(), lowerBound, upperBound, step,
1222         /*privateVars=*/ValueRange(),
1223         /*firstprivateVars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
1224         /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
1225         /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr,
1226         /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr,
1227         /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr,
1228         /*inclusive=*/nullptr, /*buildBody=*/false);
1229   state.addAttributes(attributes);
1230 }
1231 
1232 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes,
1233                      ValueRange operands, ArrayRef<NamedAttribute> attributes) {
1234   state.addOperands(operands);
1235   state.addAttributes(attributes);
1236   (void)state.addRegion();
1237   assert(resultTypes.empty() && "mismatched number of return types");
1238   state.addTypes(resultTypes);
1239 }
1240 
1241 void WsLoopOp::build(OpBuilder &builder, OperationState &result,
1242                      TypeRange typeRange, ValueRange lowerBounds,
1243                      ValueRange upperBounds, ValueRange steps,
1244                      ValueRange privateVars, ValueRange firstprivateVars,
1245                      ValueRange lastprivateVars, ValueRange linearVars,
1246                      ValueRange linearStepVars, ValueRange reductionVars,
1247                      StringAttr scheduleVal, Value scheduleChunkVar,
1248                      IntegerAttr collapseVal, UnitAttr nowait,
1249                      IntegerAttr orderedVal, StringAttr orderVal,
1250                      UnitAttr inclusive, bool buildBody) {
1251   result.addOperands(lowerBounds);
1252   result.addOperands(upperBounds);
1253   result.addOperands(steps);
1254   result.addOperands(privateVars);
1255   result.addOperands(firstprivateVars);
1256   result.addOperands(linearVars);
1257   result.addOperands(linearStepVars);
1258   if (scheduleChunkVar)
1259     result.addOperands(scheduleChunkVar);
1260 
1261   if (scheduleVal)
1262     result.addAttribute("schedule_val", scheduleVal);
1263   if (collapseVal)
1264     result.addAttribute("collapse_val", collapseVal);
1265   if (nowait)
1266     result.addAttribute("nowait", nowait);
1267   if (orderedVal)
1268     result.addAttribute("ordered_val", orderedVal);
1269   if (orderVal)
1270     result.addAttribute("order", orderVal);
1271   if (inclusive)
1272     result.addAttribute("inclusive", inclusive);
1273   result.addAttribute(
1274       WsLoopOp::getOperandSegmentSizeAttr(),
1275       builder.getI32VectorAttr(
1276           {static_cast<int32_t>(lowerBounds.size()),
1277            static_cast<int32_t>(upperBounds.size()),
1278            static_cast<int32_t>(steps.size()),
1279            static_cast<int32_t>(privateVars.size()),
1280            static_cast<int32_t>(firstprivateVars.size()),
1281            static_cast<int32_t>(lastprivateVars.size()),
1282            static_cast<int32_t>(linearVars.size()),
1283            static_cast<int32_t>(linearStepVars.size()),
1284            static_cast<int32_t>(reductionVars.size()),
1285            static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)}));
1286 
1287   Region *bodyRegion = result.addRegion();
1288   if (buildBody) {
1289     OpBuilder::InsertionGuard guard(builder);
1290     unsigned numIVs = steps.size();
1291     SmallVector<Type, 8> argTypes(numIVs, steps.getType().front());
1292     builder.createBlock(bodyRegion, {}, argTypes);
1293   }
1294 }
1295 
1296 static LogicalResult verifyWsLoopOp(WsLoopOp op) {
1297   return verifyReductionVarList(op, op.reductions(), op.reduction_vars());
1298 }
1299 
1300 //===----------------------------------------------------------------------===//
1301 // Verifier for critical construct (2.17.1)
1302 //===----------------------------------------------------------------------===//
1303 
1304 static LogicalResult verifyCriticalDeclareOp(CriticalDeclareOp op) {
1305   return verifySynchronizationHint(op, op.hint());
1306 }
1307 
1308 static LogicalResult verifyCriticalOp(CriticalOp op) {
1309 
1310   if (op.nameAttr()) {
1311     auto symbolRef = op.nameAttr().cast<SymbolRefAttr>();
1312     auto decl =
1313         SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef);
1314     if (!decl) {
1315       return op.emitOpError() << "expected symbol reference " << symbolRef
1316                               << " to point to a critical declaration";
1317     }
1318   }
1319 
1320   return success();
1321 }
1322 
1323 //===----------------------------------------------------------------------===//
1324 // Verifier for ordered construct
1325 //===----------------------------------------------------------------------===//
1326 
1327 static LogicalResult verifyOrderedOp(OrderedOp op) {
1328   auto container = op->getParentOfType<WsLoopOp>();
1329   if (!container || !container.ordered_valAttr() ||
1330       container.ordered_valAttr().getInt() == 0)
1331     return op.emitOpError() << "ordered depend directive must be closely "
1332                             << "nested inside a worksharing-loop with ordered "
1333                             << "clause with parameter present";
1334 
1335   if (container.ordered_valAttr().getInt() !=
1336       (int64_t)op.num_loops_val().getValue())
1337     return op.emitOpError() << "number of variables in depend clause does not "
1338                             << "match number of iteration variables in the "
1339                             << "doacross loop";
1340 
1341   return success();
1342 }
1343 
1344 static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) {
1345   // TODO: The code generation for ordered simd directive is not supported yet.
1346   if (op.simd())
1347     return failure();
1348 
1349   if (auto container = op->getParentOfType<WsLoopOp>()) {
1350     if (!container.ordered_valAttr() ||
1351         container.ordered_valAttr().getInt() != 0)
1352       return op.emitOpError() << "ordered region must be closely nested inside "
1353                               << "a worksharing-loop region with an ordered "
1354                               << "clause without parameter present";
1355   }
1356 
1357   return success();
1358 }
1359 
1360 //===----------------------------------------------------------------------===//
1361 // AtomicReadOp
1362 //===----------------------------------------------------------------------===//
1363 
1364 /// Parser for AtomicReadOp
1365 ///
1366 /// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type
1367 /// address ::= operand `:` type
1368 static ParseResult parseAtomicReadOp(OpAsmParser &parser,
1369                                      OperationState &result) {
1370   OpAsmParser::OperandType x, v;
1371   Type addressType;
1372   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1373   SmallVector<int> segments;
1374 
1375   if (parser.parseOperand(v) || parser.parseEqual() || parser.parseOperand(x) ||
1376       parseClauses(parser, result, clauses, segments) ||
1377       parser.parseColonType(addressType) ||
1378       parser.resolveOperand(x, addressType, result.operands) ||
1379       parser.resolveOperand(v, addressType, result.operands))
1380     return failure();
1381 
1382   return success();
1383 }
1384 
1385 /// Printer for AtomicReadOp
1386 static void printAtomicReadOp(OpAsmPrinter &p, AtomicReadOp op) {
1387   p << " " << op.v() << " = " << op.x() << " ";
1388   if (auto mo = op.memory_order())
1389     p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
1390   if (op.hintAttr())
1391     printSynchronizationHint(p << " ", op, op.hintAttr());
1392   p << ": " << op.x().getType();
1393 }
1394 
1395 /// Verifier for AtomicReadOp
1396 static LogicalResult verifyAtomicReadOp(AtomicReadOp op) {
1397   if (auto mo = op.memory_order()) {
1398     if (*mo == ClauseMemoryOrderKind::acq_rel ||
1399         *mo == ClauseMemoryOrderKind::release) {
1400       return op.emitError(
1401           "memory-order must not be acq_rel or release for atomic reads");
1402     }
1403   }
1404   if (op.x() == op.v())
1405     return op.emitError(
1406         "read and write must not be to the same location for atomic reads");
1407   return verifySynchronizationHint(op, op.hint());
1408 }
1409 
1410 //===----------------------------------------------------------------------===//
1411 // AtomicWriteOp
1412 //===----------------------------------------------------------------------===//
1413 
1414 /// Parser for AtomicWriteOp
1415 ///
1416 /// operation ::= `omp.atomic.write` atomic-clause-list operands
1417 /// operands ::= address `,` value
1418 /// address ::= operand `:` type
1419 /// value ::= operand `:` type
1420 static ParseResult parseAtomicWriteOp(OpAsmParser &parser,
1421                                       OperationState &result) {
1422   OpAsmParser::OperandType address, value;
1423   Type addrType, valueType;
1424   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1425   SmallVector<int> segments;
1426 
1427   if (parser.parseOperand(address) || parser.parseEqual() ||
1428       parser.parseOperand(value) ||
1429       parseClauses(parser, result, clauses, segments) ||
1430       parser.parseColonType(addrType) || parser.parseComma() ||
1431       parser.parseType(valueType) ||
1432       parser.resolveOperand(address, addrType, result.operands) ||
1433       parser.resolveOperand(value, valueType, result.operands))
1434     return failure();
1435   return success();
1436 }
1437 
1438 /// Printer for AtomicWriteOp
1439 static void printAtomicWriteOp(OpAsmPrinter &p, AtomicWriteOp op) {
1440   p << " " << op.address() << " = " << op.value() << " ";
1441   if (auto mo = op.memory_order())
1442     p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
1443   if (op.hintAttr())
1444     printSynchronizationHint(p, op, op.hintAttr());
1445   p << ": " << op.address().getType() << ", " << op.value().getType();
1446 }
1447 
1448 /// Verifier for AtomicWriteOp
1449 static LogicalResult verifyAtomicWriteOp(AtomicWriteOp op) {
1450   if (auto mo = op.memory_order()) {
1451     if (*mo == ClauseMemoryOrderKind::acq_rel ||
1452         *mo == ClauseMemoryOrderKind::acquire) {
1453       return op.emitError(
1454           "memory-order must not be acq_rel or acquire for atomic writes");
1455     }
1456   }
1457   return verifySynchronizationHint(op, op.hint());
1458 }
1459 
1460 //===----------------------------------------------------------------------===//
1461 // AtomicUpdateOp
1462 //===----------------------------------------------------------------------===//
1463 
1464 /// Parser for AtomicUpdateOp
1465 ///
1466 /// operation ::= `omp.atomic.update` atomic-clause-list region
1467 static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
1468                                        OperationState &result) {
1469   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1470   SmallVector<int> segments;
1471   OpAsmParser::OperandType x, y, z;
1472   Type xType, exprType;
1473   StringRef binOp;
1474 
1475   // x = y `op` z : xtype, exprtype
1476   if (parser.parseOperand(x) || parser.parseEqual() || parser.parseOperand(y) ||
1477       parser.parseKeyword(&binOp) || parser.parseOperand(z) ||
1478       parseClauses(parser, result, clauses, segments) || parser.parseColon() ||
1479       parser.parseType(xType) || parser.parseComma() ||
1480       parser.parseType(exprType) ||
1481       parser.resolveOperand(x, xType, result.operands)) {
1482     return failure();
1483   }
1484 
1485   auto binOpEnum = AtomicBinOpKindToEnum(binOp.upper());
1486   if (!binOpEnum)
1487     return parser.emitError(parser.getNameLoc())
1488            << "invalid atomic bin op in atomic update\n";
1489   auto attr =
1490       parser.getBuilder().getI64IntegerAttr((int64_t)binOpEnum.getValue());
1491   result.addAttribute("binop", attr);
1492 
1493   OpAsmParser::OperandType expr;
1494   if (x.name == y.name && x.number == y.number) {
1495     expr = z;
1496     result.addAttribute("isXBinopExpr", parser.getBuilder().getUnitAttr());
1497   } else if (x.name == z.name && x.number == z.number) {
1498     expr = y;
1499   } else {
1500     return parser.emitError(parser.getNameLoc())
1501            << "atomic update variable " << x.name
1502            << " not found in the RHS of the assignment statement in an"
1503               " atomic.update operation";
1504   }
1505   return parser.resolveOperand(expr, exprType, result.operands);
1506 }
1507 
1508 /// Printer for AtomicUpdateOp
1509 static void printAtomicUpdateOp(OpAsmPrinter &p, AtomicUpdateOp op) {
1510   p << " " << op.x() << " = ";
1511   Value y, z;
1512   if (op.isXBinopExpr()) {
1513     y = op.x();
1514     z = op.expr();
1515   } else {
1516     y = op.expr();
1517     z = op.x();
1518   }
1519   p << y << " " << AtomicBinOpKindToString(op.binop()).lower() << " " << z
1520     << " ";
1521   if (auto mo = op.memory_order())
1522     p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
1523   if (op.hintAttr())
1524     printSynchronizationHint(p, op, op.hintAttr());
1525   p << ": " << op.x().getType() << ", " << op.expr().getType();
1526 }
1527 
1528 /// Verifier for AtomicUpdateOp
1529 static LogicalResult verifyAtomicUpdateOp(AtomicUpdateOp op) {
1530   if (auto mo = op.memory_order()) {
1531     if (*mo == ClauseMemoryOrderKind::acq_rel ||
1532         *mo == ClauseMemoryOrderKind::acquire) {
1533       return op.emitError(
1534           "memory-order must not be acq_rel or acquire for atomic updates");
1535     }
1536   }
1537   return success();
1538 }
1539 
1540 #define GET_ATTRDEF_CLASSES
1541 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
1542 
1543 #define GET_OP_CLASSES
1544 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
1545