1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/OperationSupport.h"
19 
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/SmallString.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/ADT/StringSwitch.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include <cstddef>
27 
28 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
29 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
30 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
31 
32 using namespace mlir;
33 using namespace mlir::omp;
34 
35 namespace {
36 /// Model for pointer-like types that already provide a `getElementType` method.
37 template <typename T>
38 struct PointerLikeModel
39     : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
40   Type getElementType(Type pointer) const {
41     return pointer.cast<T>().getElementType();
42   }
43 };
44 } // namespace
45 
/// Registers the dialect's operations and attributes (from the TableGen'd
/// lists), then attaches the `PointerLikeType` interface, via
/// `PointerLikeModel`, to `LLVM::LLVMPointerType` and `MemRefType`.
void OpenMPDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
      >();

  // Both types already expose `getElementType()`, which the model forwards to.
  LLVM::LLVMPointerType::attachInterface<
      PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
  MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
}
60 
61 //===----------------------------------------------------------------------===//
62 // ParallelOp
63 //===----------------------------------------------------------------------===//
64 
65 void ParallelOp::build(OpBuilder &builder, OperationState &state,
66                        ArrayRef<NamedAttribute> attributes) {
67   ParallelOp::build(
68       builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
69       /*default_val=*/nullptr, /*private_vars=*/ValueRange(),
70       /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(),
71       /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(),
72       /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr);
73   state.addAttributes(attributes);
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // Parser and printer for Operand and type list
78 //===----------------------------------------------------------------------===//
79 
80 /// Parse a list of operands with types.
81 ///
82 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
83 /// ssa-id-and-type-list ::= ssa-id-and-type |
84 ///                          ssa-id-and-type `,` ssa-id-and-type-list
85 /// ssa-id-and-type ::= ssa-id `:` type
86 static ParseResult
87 parseOperandAndTypeList(OpAsmParser &parser,
88                         SmallVectorImpl<OpAsmParser::OperandType> &operands,
89                         SmallVectorImpl<Type> &types) {
90   return parser.parseCommaSeparatedList(
91       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
92         OpAsmParser::OperandType operand;
93         Type type;
94         if (parser.parseOperand(operand) || parser.parseColonType(type))
95           return failure();
96         operands.push_back(operand);
97         types.push_back(type);
98         return success();
99       });
100 }
101 
102 /// Print an operand and type list with parentheses
103 static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) {
104   p << "(";
105   llvm::interleaveComma(
106       operands, p, [&](const Value &v) { p << v << " : " << v.getType(); });
107   p << ") ";
108 }
109 
110 /// Print data variables corresponding to a data-sharing clause `name`
111 static void printDataVars(OpAsmPrinter &p, OperandRange operands,
112                           StringRef name) {
113   if (!operands.empty()) {
114     p << name;
115     printOperandAndTypeList(p, operands);
116   }
117 }
118 
119 //===----------------------------------------------------------------------===//
120 // Parser and printer for Allocate Clause
121 //===----------------------------------------------------------------------===//
122 
123 /// Parse an allocate clause with allocators and a list of operands with types.
124 ///
125 /// allocate ::= `allocate` `(` allocate-operand-list `)`
126 /// allocate-operand-list :: = allocate-operand |
127 ///                            allocator-operand `,` allocate-operand-list
128 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
129 /// ssa-id-and-type ::= ssa-id `:` type
130 static ParseResult parseAllocateAndAllocator(
131     OpAsmParser &parser,
132     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
133     SmallVectorImpl<Type> &typesAllocate,
134     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
135     SmallVectorImpl<Type> &typesAllocator) {
136 
137   return parser.parseCommaSeparatedList(
138       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
139         OpAsmParser::OperandType operand;
140         Type type;
141         if (parser.parseOperand(operand) || parser.parseColonType(type))
142           return failure();
143         operandsAllocator.push_back(operand);
144         typesAllocator.push_back(type);
145         if (parser.parseArrow())
146           return failure();
147         if (parser.parseOperand(operand) || parser.parseColonType(type))
148           return failure();
149 
150         operandsAllocate.push_back(operand);
151         typesAllocate.push_back(type);
152         return success();
153       });
154 }
155 
156 /// Print allocate clause
157 static void printAllocateAndAllocator(OpAsmPrinter &p,
158                                       OperandRange varsAllocate,
159                                       OperandRange varsAllocator) {
160   p << "allocate(";
161   for (unsigned i = 0; i < varsAllocate.size(); ++i) {
162     std::string separator = i == varsAllocate.size() - 1 ? ") " : ", ";
163     p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
164     p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
165   }
166 }
167 
168 static LogicalResult verifyParallelOp(ParallelOp op) {
169   if (op.allocate_vars().size() != op.allocators_vars().size())
170     return op.emitError(
171         "expected equal sizes for allocate and allocator variables");
172   return success();
173 }
174 
175 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
176   p << " ";
177   if (auto ifCond = op.if_expr_var())
178     p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
179 
180   if (auto threads = op.num_threads_var())
181     p << "num_threads(" << threads << " : " << threads.getType() << ") ";
182 
183   printDataVars(p, op.private_vars(), "private");
184   printDataVars(p, op.firstprivate_vars(), "firstprivate");
185   printDataVars(p, op.shared_vars(), "shared");
186   printDataVars(p, op.copyin_vars(), "copyin");
187 
188   if (!op.allocate_vars().empty())
189     printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars());
190 
191   if (auto def = op.default_val())
192     p << "default(" << stringifyClauseDefault(*def).drop_front(3) << ") ";
193 
194   if (auto bind = op.proc_bind_val())
195     p << "proc_bind(" << stringifyClauseProcBindKind(*bind) << ") ";
196 
197   p << ' ';
198   p.printRegion(op.getRegion());
199 }
200 
201 static void printTargetOp(OpAsmPrinter &p, TargetOp op) {
202   p << " ";
203   if (auto ifCond = op.if_expr())
204     p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
205 
206   if (auto device = op.device())
207     p << "device(" << device << " : " << device.getType() << ") ";
208 
209   if (auto threads = op.thread_limit())
210     p << "thread_limit(" << threads << " : " << threads.getType() << ") ";
211 
212   if (op.nowait()) {
213     p << "nowait ";
214   }
215 
216   p.printRegion(op.getRegion());
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // Parser and printer for Linear Clause
221 //===----------------------------------------------------------------------===//
222 
223 /// linear ::= `linear` `(` linear-list `)`
224 /// linear-list := linear-val | linear-val linear-list
225 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
226 static ParseResult
227 parseLinearClause(OpAsmParser &parser,
228                   SmallVectorImpl<OpAsmParser::OperandType> &vars,
229                   SmallVectorImpl<Type> &types,
230                   SmallVectorImpl<OpAsmParser::OperandType> &stepVars) {
231   if (parser.parseLParen())
232     return failure();
233 
234   do {
235     OpAsmParser::OperandType var;
236     Type type;
237     OpAsmParser::OperandType stepVar;
238     if (parser.parseOperand(var) || parser.parseEqual() ||
239         parser.parseOperand(stepVar) || parser.parseColonType(type))
240       return failure();
241 
242     vars.push_back(var);
243     types.push_back(type);
244     stepVars.push_back(stepVar);
245   } while (succeeded(parser.parseOptionalComma()));
246 
247   if (parser.parseRParen())
248     return failure();
249 
250   return success();
251 }
252 
253 /// Print Linear Clause
254 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars,
255                               OperandRange linearStepVars) {
256   size_t linearVarsSize = linearVars.size();
257   p << "linear(";
258   for (unsigned i = 0; i < linearVarsSize; ++i) {
259     std::string separator = i == linearVarsSize - 1 ? ") " : ", ";
260     p << linearVars[i];
261     if (linearStepVars.size() > i)
262       p << " = " << linearStepVars[i];
263     p << " : " << linearVars[i].getType() << separator;
264   }
265 }
266 
267 //===----------------------------------------------------------------------===//
268 // Parser and printer for Schedule Clause
269 //===----------------------------------------------------------------------===//
270 
271 static ParseResult
272 verifyScheduleModifiers(OpAsmParser &parser,
273                         SmallVectorImpl<SmallString<12>> &modifiers) {
274   if (modifiers.size() > 2)
275     return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
276   for (const auto &mod : modifiers) {
277     // Translate the string. If it has no value, then it was not a valid
278     // modifier!
279     auto symbol = symbolizeScheduleModifier(mod);
280     if (!symbol.hasValue())
281       return parser.emitError(parser.getNameLoc())
282              << " unknown modifier type: " << mod;
283   }
284 
285   // If we have one modifier that is "simd", then stick a "none" modiifer in
286   // index 0.
287   if (modifiers.size() == 1) {
288     if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
289       modifiers.push_back(modifiers[0]);
290       modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
291     }
292   } else if (modifiers.size() == 2) {
293     // If there are two modifier:
294     // First modifier should not be simd, second one should be simd
295     if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
296         symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
297       return parser.emitError(parser.getNameLoc())
298              << " incorrect modifier order";
299   }
300   return success();
301 }
302 
303 /// schedule ::= `schedule` `(` sched-list `)`
304 /// sched-list ::= sched-val | sched-val sched-list |
305 ///                sched-val `,` sched-modifier
306 /// sched-val ::= sched-with-chunk | sched-wo-chunk
307 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
308 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
309 /// sched-wo-chunk ::=  `auto` | `runtime`
310 /// sched-modifier ::=  sched-mod-val | sched-mod-val `,` sched-mod-val
311 /// sched-mod-val ::=  `monotonic` | `nonmonotonic` | `simd` | `none`
312 static ParseResult
313 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
314                     SmallVectorImpl<SmallString<12>> &modifiers,
315                     Optional<OpAsmParser::OperandType> &chunkSize,
316                     Type &chunkType) {
317   if (parser.parseLParen())
318     return failure();
319 
320   StringRef keyword;
321   if (parser.parseKeyword(&keyword))
322     return failure();
323 
324   schedule = keyword;
325   if (keyword == "static" || keyword == "dynamic" || keyword == "guided") {
326     if (succeeded(parser.parseOptionalEqual())) {
327       chunkSize = OpAsmParser::OperandType{};
328       if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
329         return failure();
330     } else {
331       chunkSize = llvm::NoneType::None;
332     }
333   } else if (keyword == "auto" || keyword == "runtime") {
334     chunkSize = llvm::NoneType::None;
335   } else {
336     return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
337   }
338 
339   // If there is a comma, we have one or more modifiers..
340   while (succeeded(parser.parseOptionalComma())) {
341     StringRef mod;
342     if (parser.parseKeyword(&mod))
343       return failure();
344     modifiers.push_back(mod);
345   }
346 
347   if (parser.parseRParen())
348     return failure();
349 
350   if (verifyScheduleModifiers(parser, modifiers))
351     return failure();
352 
353   return success();
354 }
355 
356 /// Print schedule clause
357 static void printScheduleClause(OpAsmPrinter &p, ClauseScheduleKind sched,
358                                 Optional<ScheduleModifier> modifier, bool simd,
359                                 Value scheduleChunkVar) {
360   p << "schedule(" << stringifyClauseScheduleKind(sched).lower();
361   if (scheduleChunkVar)
362     p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType();
363   if (modifier)
364     p << ", " << stringifyScheduleModifier(*modifier);
365   if (simd)
366     p << ", simd";
367   p << ") ";
368 }
369 
370 //===----------------------------------------------------------------------===//
371 // Parser, printer and verifier for ReductionVarList
372 //===----------------------------------------------------------------------===//
373 
374 /// reduction ::= `reduction` `(` reduction-entry-list `)`
375 /// reduction-entry-list ::= reduction-entry
376 ///                        | reduction-entry-list `,` reduction-entry
377 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
378 static ParseResult
379 parseReductionVarList(OpAsmParser &parser,
380                       SmallVectorImpl<SymbolRefAttr> &symbols,
381                       SmallVectorImpl<OpAsmParser::OperandType> &operands,
382                       SmallVectorImpl<Type> &types) {
383   if (failed(parser.parseLParen()))
384     return failure();
385 
386   do {
387     if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() ||
388         parser.parseOperand(operands.emplace_back()) ||
389         parser.parseColonType(types.emplace_back()))
390       return failure();
391   } while (succeeded(parser.parseOptionalComma()));
392   return parser.parseRParen();
393 }
394 
395 /// Print Reduction clause
396 static void printReductionVarList(OpAsmPrinter &p,
397                                   Optional<ArrayAttr> reductions,
398                                   OperandRange reductionVars) {
399   p << "reduction(";
400   for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
401     if (i != 0)
402       p << ", ";
403     p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
404       << reductionVars[i].getType();
405   }
406   p << ") ";
407 }
408 
409 /// Verifies Reduction Clause
410 static LogicalResult verifyReductionVarList(Operation *op,
411                                             Optional<ArrayAttr> reductions,
412                                             OperandRange reductionVars) {
413   if (!reductionVars.empty()) {
414     if (!reductions || reductions->size() != reductionVars.size())
415       return op->emitOpError()
416              << "expected as many reduction symbol references "
417                 "as reduction variables";
418   } else {
419     if (reductions)
420       return op->emitOpError() << "unexpected reduction symbol references";
421     return success();
422   }
423 
424   DenseSet<Value> accumulators;
425   for (auto args : llvm::zip(reductionVars, *reductions)) {
426     Value accum = std::get<0>(args);
427 
428     if (!accumulators.insert(accum).second)
429       return op->emitOpError() << "accumulator variable used more than once";
430 
431     Type varType = accum.getType().cast<PointerLikeType>();
432     auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
433     auto decl =
434         SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
435     if (!decl)
436       return op->emitOpError() << "expected symbol reference " << symbolRef
437                                << " to point to a reduction declaration";
438 
439     if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
440       return op->emitOpError()
441              << "expected accumulator (" << varType
442              << ") to be the same type as reduction declaration ("
443              << decl.getAccumulatorType() << ")";
444   }
445 
446   return success();
447 }
448 
449 //===----------------------------------------------------------------------===//
450 // Parser, printer and verifier for Synchronization Hint (2.17.12)
451 //===----------------------------------------------------------------------===//
452 
453 /// Parses a Synchronization Hint clause. The value of hint is an integer
454 /// which is a combination of different hints from `omp_sync_hint_t`.
455 ///
456 /// hint-clause = `hint` `(` hint-value `)`
457 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
458                                             IntegerAttr &hintAttr,
459                                             bool parseKeyword = true) {
460   if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) {
461     hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
462     return success();
463   }
464 
465   if (failed(parser.parseLParen()))
466     return failure();
467   StringRef hintKeyword;
468   int64_t hint = 0;
469   do {
470     if (failed(parser.parseKeyword(&hintKeyword)))
471       return failure();
472     if (hintKeyword == "uncontended")
473       hint |= 1;
474     else if (hintKeyword == "contended")
475       hint |= 2;
476     else if (hintKeyword == "nonspeculative")
477       hint |= 4;
478     else if (hintKeyword == "speculative")
479       hint |= 8;
480     else
481       return parser.emitError(parser.getCurrentLocation())
482              << hintKeyword << " is not a valid hint";
483   } while (succeeded(parser.parseOptionalComma()));
484   if (failed(parser.parseRParen()))
485     return failure();
486   hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
487   return success();
488 }
489 
490 /// Prints a Synchronization Hint clause
491 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
492                                      IntegerAttr hintAttr) {
493   int64_t hint = hintAttr.getInt();
494 
495   if (hint == 0)
496     return;
497 
498   // Helper function to get n-th bit from the right end of `value`
499   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
500 
501   bool uncontended = bitn(hint, 0);
502   bool contended = bitn(hint, 1);
503   bool nonspeculative = bitn(hint, 2);
504   bool speculative = bitn(hint, 3);
505 
506   SmallVector<StringRef> hints;
507   if (uncontended)
508     hints.push_back("uncontended");
509   if (contended)
510     hints.push_back("contended");
511   if (nonspeculative)
512     hints.push_back("nonspeculative");
513   if (speculative)
514     hints.push_back("speculative");
515 
516   p << "hint(";
517   llvm::interleaveComma(hints, p);
518   p << ") ";
519 }
520 
521 /// Verifies a synchronization hint clause
522 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
523 
524   // Helper function to get n-th bit from the right end of `value`
525   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
526 
527   bool uncontended = bitn(hint, 0);
528   bool contended = bitn(hint, 1);
529   bool nonspeculative = bitn(hint, 2);
530   bool speculative = bitn(hint, 3);
531 
532   if (uncontended && contended)
533     return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
534                                 "omp_sync_hint_contended cannot be combined";
535   if (nonspeculative && speculative)
536     return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
537                                 "omp_sync_hint_speculative cannot be combined.";
538   return success();
539 }
540 
/// Identifies each clause that `parseClauses` can parse. Enumerator values are
/// used as indices into the per-clause tables (`done`, `pos`) in
/// `parseClauses`. `COUNT` is the number of clauses, not a clause itself.
///
/// In `parseClauses`, `defaultClause`, `procBindClause`, `nowaitClause`,
/// `collapseClause`, `orderClause` and `orderedClause` take no position in the
/// operand segments. `allocateClause` and `linearClause` reserve two positions
/// each.
enum ClauseType {
  ifClause,
  numThreadsClause,
  deviceClause,
  threadLimitClause,
  privateClause,
  firstprivateClause,
  lastprivateClause,
  sharedClause,
  copyinClause,
  allocateClause,
  defaultClause,
  procBindClause,
  reductionClause,
  nowaitClause,
  linearClause,
  scheduleClause,
  collapseClause,
  orderClause,
  orderedClause,
  memoryOrderClause,
  hintClause,
  // Number of clause kinds; used to size per-clause tables.
  COUNT
};
565 
566 //===----------------------------------------------------------------------===//
567 // Parser for Clause List
568 //===----------------------------------------------------------------------===//
569 
570 /// Parse a clause attribute `(` $value `)`.
571 template <typename ClauseAttr>
572 static ParseResult parseClauseAttr(AsmParser &parser, OperationState &state,
573                                    StringRef attrName, StringRef name) {
574   using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
575   StringRef enumStr;
576   SMLoc loc = parser.getCurrentLocation();
577   if (parser.parseLParen() || parser.parseKeyword(&enumStr) ||
578       parser.parseRParen())
579     return failure();
580   if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
581     auto attr = ClauseAttr::get(parser.getContext(), *enumValue);
582     state.addAttribute(attrName, attr);
583     return success();
584   }
585   return parser.emitError(loc, "invalid ") << name << " kind";
586 }
587 
/// Parse a list of clauses. The clauses can appear in any order, but their
/// operand segment indices are in the same order that they are passed in the
/// `clauses` list. The operand segments are added on top of the segments
/// already present in `segments`.
///
592 /// clause-list ::= clause clause-list | empty
593 /// clause ::= if | num-threads | private | firstprivate | lastprivate |
594 ///            shared | copyin | allocate | default | proc-bind | reduction |
595 ///            nowait | linear | schedule | collapse | order | ordered |
596 ///            inclusive
597 /// if ::= `if` `(` ssa-id-and-type `)`
598 /// num-threads ::= `num_threads` `(` ssa-id-and-type `)`
599 /// private ::= `private` operand-and-type-list
600 /// firstprivate ::= `firstprivate` operand-and-type-list
601 /// lastprivate ::= `lastprivate` operand-and-type-list
602 /// shared ::= `shared` operand-and-type-list
603 /// copyin ::= `copyin` operand-and-type-list
604 /// allocate ::= `allocate` `(` allocate-operand-list `)`
/// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) `)`
606 /// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
607 /// reduction ::= `reduction` `(` reduction-entry-list `)`
608 /// nowait ::= `nowait`
609 /// linear ::= `linear` `(` linear-list `)`
610 /// schedule ::= `schedule` `(` sched-list `)`
611 /// collapse ::= `collapse` `(` ssa-id-and-type `)`
612 /// order ::= `order` `(` `concurrent` `)`
613 /// ordered ::= `ordered` `(` ssa-id-and-type `)`
614 /// inclusive ::= `inclusive`
615 ///
/// Note that each clause can only appear once in the clause-list.
617 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
618                                 SmallVectorImpl<ClauseType> &clauses,
619                                 SmallVectorImpl<int> &segments) {
620 
621   // Check done[clause] to see if it has been parsed already
622   BitVector done(ClauseType::COUNT, false);
623 
624   // See pos[clause] to get position of clause in operand segments
625   SmallVector<int> pos(ClauseType::COUNT, -1);
626 
627   // Stores the last parsed clause keyword
628   StringRef clauseKeyword;
629   StringRef opName = result.name.getStringRef();
630 
631   // Containers for storing operands, types and attributes for various clauses
632   std::pair<OpAsmParser::OperandType, Type> ifCond;
633   std::pair<OpAsmParser::OperandType, Type> numThreads;
634   std::pair<OpAsmParser::OperandType, Type> device;
635   std::pair<OpAsmParser::OperandType, Type> threadLimit;
636 
637   SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates,
638       shareds, copyins;
639   SmallVector<Type> privateTypes, firstprivateTypes, lastprivateTypes,
640       sharedTypes, copyinTypes;
641 
642   SmallVector<OpAsmParser::OperandType> allocates, allocators;
643   SmallVector<Type> allocateTypes, allocatorTypes;
644 
645   SmallVector<SymbolRefAttr> reductionSymbols;
646   SmallVector<OpAsmParser::OperandType> reductionVars;
647   SmallVector<Type> reductionVarTypes;
648 
649   SmallVector<OpAsmParser::OperandType> linears;
650   SmallVector<Type> linearTypes;
651   SmallVector<OpAsmParser::OperandType> linearSteps;
652 
653   SmallString<8> schedule;
654   SmallVector<SmallString<12>> modifiers;
655   Optional<OpAsmParser::OperandType> scheduleChunkSize;
656   Type scheduleChunkType;
657 
658   // Compute the position of clauses in operand segments
659   int currPos = 0;
660   for (ClauseType clause : clauses) {
661 
662     // Skip the following clauses - they do not take any position in operand
663     // segments
664     if (clause == defaultClause || clause == procBindClause ||
665         clause == nowaitClause || clause == collapseClause ||
666         clause == orderClause || clause == orderedClause)
667       continue;
668 
669     pos[clause] = currPos++;
670 
671     // For the following clauses, two positions are reserved in the operand
672     // segments
673     if (clause == allocateClause || clause == linearClause)
674       currPos++;
675   }
676 
677   SmallVector<int> clauseSegments(currPos);
678 
679   // Helper function to check if a clause is allowed/repeated or not
680   auto checkAllowed = [&](ClauseType clause) -> ParseResult {
681     if (!llvm::is_contained(clauses, clause))
682       return parser.emitError(parser.getCurrentLocation())
683              << clauseKeyword << " is not a valid clause for the " << opName
684              << " operation";
685     if (done[clause])
686       return parser.emitError(parser.getCurrentLocation())
687              << "at most one " << clauseKeyword << " clause can appear on the "
688              << opName << " operation";
689     done[clause] = true;
690     return success();
691   };
692 
693   while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) {
694     if (clauseKeyword == "if") {
695       if (checkAllowed(ifClause) || parser.parseLParen() ||
696           parser.parseOperand(ifCond.first) ||
697           parser.parseColonType(ifCond.second) || parser.parseRParen())
698         return failure();
699       clauseSegments[pos[ifClause]] = 1;
700     } else if (clauseKeyword == "num_threads") {
701       if (checkAllowed(numThreadsClause) || parser.parseLParen() ||
702           parser.parseOperand(numThreads.first) ||
703           parser.parseColonType(numThreads.second) || parser.parseRParen())
704         return failure();
705       clauseSegments[pos[numThreadsClause]] = 1;
706     } else if (clauseKeyword == "device") {
707       if (checkAllowed(deviceClause) || parser.parseLParen() ||
708           parser.parseOperand(device.first) ||
709           parser.parseColonType(device.second) || parser.parseRParen())
710         return failure();
711       clauseSegments[pos[deviceClause]] = 1;
712     } else if (clauseKeyword == "thread_limit") {
713       if (checkAllowed(threadLimitClause) || parser.parseLParen() ||
714           parser.parseOperand(threadLimit.first) ||
715           parser.parseColonType(threadLimit.second) || parser.parseRParen())
716         return failure();
717       clauseSegments[pos[threadLimitClause]] = 1;
718     } else if (clauseKeyword == "private") {
719       if (checkAllowed(privateClause) ||
720           parseOperandAndTypeList(parser, privates, privateTypes))
721         return failure();
722       clauseSegments[pos[privateClause]] = privates.size();
723     } else if (clauseKeyword == "firstprivate") {
724       if (checkAllowed(firstprivateClause) ||
725           parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
726         return failure();
727       clauseSegments[pos[firstprivateClause]] = firstprivates.size();
728     } else if (clauseKeyword == "lastprivate") {
729       if (checkAllowed(lastprivateClause) ||
730           parseOperandAndTypeList(parser, lastprivates, lastprivateTypes))
731         return failure();
732       clauseSegments[pos[lastprivateClause]] = lastprivates.size();
733     } else if (clauseKeyword == "shared") {
734       if (checkAllowed(sharedClause) ||
735           parseOperandAndTypeList(parser, shareds, sharedTypes))
736         return failure();
737       clauseSegments[pos[sharedClause]] = shareds.size();
738     } else if (clauseKeyword == "copyin") {
739       if (checkAllowed(copyinClause) ||
740           parseOperandAndTypeList(parser, copyins, copyinTypes))
741         return failure();
742       clauseSegments[pos[copyinClause]] = copyins.size();
743     } else if (clauseKeyword == "allocate") {
744       if (checkAllowed(allocateClause) ||
745           parseAllocateAndAllocator(parser, allocates, allocateTypes,
746                                     allocators, allocatorTypes))
747         return failure();
748       clauseSegments[pos[allocateClause]] = allocates.size();
749       clauseSegments[pos[allocateClause] + 1] = allocators.size();
750     } else if (clauseKeyword == "default") {
751       StringRef defval;
752       SMLoc loc = parser.getCurrentLocation();
753       if (checkAllowed(defaultClause) || parser.parseLParen() ||
754           parser.parseKeyword(&defval) || parser.parseRParen())
755         return failure();
756       // The def prefix is required for the attribute as "private" is a keyword
757       // in C++.
758       if (Optional<ClauseDefault> def =
759               symbolizeClauseDefault(("def" + defval).str())) {
760         result.addAttribute("default_val",
761                             ClauseDefaultAttr::get(parser.getContext(), *def));
762       } else {
763         return parser.emitError(loc, "invalid default clause");
764       }
765     } else if (clauseKeyword == "proc_bind") {
766       if (checkAllowed(procBindClause) ||
767           parseClauseAttr<ClauseProcBindKindAttr>(parser, result,
768                                                   "proc_bind_val", "proc bind"))
769         return failure();
770     } else if (clauseKeyword == "reduction") {
771       if (checkAllowed(reductionClause) ||
772           parseReductionVarList(parser, reductionSymbols, reductionVars,
773                                 reductionVarTypes))
774         return failure();
775       clauseSegments[pos[reductionClause]] = reductionVars.size();
776     } else if (clauseKeyword == "nowait") {
777       if (checkAllowed(nowaitClause))
778         return failure();
779       auto attr = UnitAttr::get(parser.getBuilder().getContext());
780       result.addAttribute("nowait", attr);
781     } else if (clauseKeyword == "linear") {
782       if (checkAllowed(linearClause) ||
783           parseLinearClause(parser, linears, linearTypes, linearSteps))
784         return failure();
785       clauseSegments[pos[linearClause]] = linears.size();
786       clauseSegments[pos[linearClause] + 1] = linearSteps.size();
787     } else if (clauseKeyword == "schedule") {
788       if (checkAllowed(scheduleClause) ||
789           parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize,
790                               scheduleChunkType))
791         return failure();
792       if (scheduleChunkSize) {
793         clauseSegments[pos[scheduleClause]] = 1;
794       }
795     } else if (clauseKeyword == "collapse") {
796       auto type = parser.getBuilder().getI64Type();
797       mlir::IntegerAttr attr;
798       if (checkAllowed(collapseClause) || parser.parseLParen() ||
799           parser.parseAttribute(attr, type) || parser.parseRParen())
800         return failure();
801       result.addAttribute("collapse_val", attr);
802     } else if (clauseKeyword == "ordered") {
803       mlir::IntegerAttr attr;
804       if (checkAllowed(orderedClause))
805         return failure();
806       if (succeeded(parser.parseOptionalLParen())) {
807         auto type = parser.getBuilder().getI64Type();
808         if (parser.parseAttribute(attr, type) || parser.parseRParen())
809           return failure();
810       } else {
811         // Use 0 to represent no ordered parameter was specified
812         attr = parser.getBuilder().getI64IntegerAttr(0);
813       }
814       result.addAttribute("ordered_val", attr);
815     } else if (clauseKeyword == "order") {
816       if (checkAllowed(orderClause) ||
817           parseClauseAttr<ClauseOrderKindAttr>(parser, result, "order_val",
818                                                "order"))
819         return failure();
820     } else if (clauseKeyword == "memory_order") {
821       if (checkAllowed(memoryOrderClause) ||
822           parseClauseAttr<ClauseMemoryOrderKindAttr>(
823               parser, result, "memory_order", "memory order"))
824         return failure();
825     } else if (clauseKeyword == "hint") {
826       IntegerAttr hint;
827       if (checkAllowed(hintClause) ||
828           parseSynchronizationHint(parser, hint, false))
829         return failure();
830       result.addAttribute("hint", hint);
831     } else {
832       return parser.emitError(parser.getNameLoc())
833              << clauseKeyword << " is not a valid clause";
834     }
835   }
836 
837   // Add if parameter.
838   if (done[ifClause] && clauseSegments[pos[ifClause]] &&
839       failed(
840           parser.resolveOperand(ifCond.first, ifCond.second, result.operands)))
841     return failure();
842 
843   // Add num_threads parameter.
844   if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] &&
845       failed(parser.resolveOperand(numThreads.first, numThreads.second,
846                                    result.operands)))
847     return failure();
848 
849   // Add device parameter.
850   if (done[deviceClause] && clauseSegments[pos[deviceClause]] &&
851       failed(
852           parser.resolveOperand(device.first, device.second, result.operands)))
853     return failure();
854 
855   // Add thread_limit parameter.
856   if (done[threadLimitClause] && clauseSegments[pos[threadLimitClause]] &&
857       failed(parser.resolveOperand(threadLimit.first, threadLimit.second,
858                                    result.operands)))
859     return failure();
860 
861   // Add private parameters.
862   if (done[privateClause] && clauseSegments[pos[privateClause]] &&
863       failed(parser.resolveOperands(privates, privateTypes,
864                                     privates[0].location, result.operands)))
865     return failure();
866 
867   // Add firstprivate parameters.
868   if (done[firstprivateClause] && clauseSegments[pos[firstprivateClause]] &&
869       failed(parser.resolveOperands(firstprivates, firstprivateTypes,
870                                     firstprivates[0].location,
871                                     result.operands)))
872     return failure();
873 
874   // Add lastprivate parameters.
875   if (done[lastprivateClause] && clauseSegments[pos[lastprivateClause]] &&
876       failed(parser.resolveOperands(lastprivates, lastprivateTypes,
877                                     lastprivates[0].location, result.operands)))
878     return failure();
879 
880   // Add shared parameters.
881   if (done[sharedClause] && clauseSegments[pos[sharedClause]] &&
882       failed(parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
883                                     result.operands)))
884     return failure();
885 
886   // Add copyin parameters.
887   if (done[copyinClause] && clauseSegments[pos[copyinClause]] &&
888       failed(parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
889                                     result.operands)))
890     return failure();
891 
892   // Add allocate parameters.
893   if (done[allocateClause] && clauseSegments[pos[allocateClause]] &&
894       failed(parser.resolveOperands(allocates, allocateTypes,
895                                     allocates[0].location, result.operands)))
896     return failure();
897 
898   // Add allocator parameters.
899   if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] &&
900       failed(parser.resolveOperands(allocators, allocatorTypes,
901                                     allocators[0].location, result.operands)))
902     return failure();
903 
904   // Add reduction parameters and symbols
905   if (done[reductionClause] && clauseSegments[pos[reductionClause]]) {
906     if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
907                                       parser.getNameLoc(), result.operands)))
908       return failure();
909 
910     SmallVector<Attribute> reductions(reductionSymbols.begin(),
911                                       reductionSymbols.end());
912     result.addAttribute("reductions",
913                         parser.getBuilder().getArrayAttr(reductions));
914   }
915 
916   // Add linear parameters
917   if (done[linearClause] && clauseSegments[pos[linearClause]]) {
918     auto linearStepType = parser.getBuilder().getI32Type();
919     SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
920     if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location,
921                                       result.operands)) ||
922         failed(parser.resolveOperands(linearSteps, linearStepTypes,
923                                       linearSteps[0].location,
924                                       result.operands)))
925       return failure();
926   }
927 
928   // Add schedule parameters
929   if (done[scheduleClause] && !schedule.empty()) {
930     schedule[0] = llvm::toUpper(schedule[0]);
931     if (Optional<ClauseScheduleKind> sched =
932             symbolizeClauseScheduleKind(schedule)) {
933       auto attr = ClauseScheduleKindAttr::get(parser.getContext(), *sched);
934       result.addAttribute("schedule_val", attr);
935     } else {
936       return parser.emitError(parser.getCurrentLocation(),
937                               "invalid schedule kind");
938     }
939     if (!modifiers.empty()) {
940       SMLoc loc = parser.getCurrentLocation();
941       if (Optional<ScheduleModifier> mod =
942               symbolizeScheduleModifier(modifiers[0])) {
943         result.addAttribute(
944             "schedule_modifier",
945             ScheduleModifierAttr::get(parser.getContext(), *mod));
946       } else {
947         return parser.emitError(loc, "invalid schedule modifier");
948       }
949       // Only SIMD attribute is allowed here!
950       if (modifiers.size() > 1) {
951         assert(symbolizeScheduleModifier(modifiers[1]) ==
952                ScheduleModifier::simd);
953         auto attr = UnitAttr::get(parser.getBuilder().getContext());
954         result.addAttribute("simd_modifier", attr);
955       }
956     }
957     if (scheduleChunkSize)
958       parser.resolveOperand(*scheduleChunkSize, scheduleChunkType,
959                             result.operands);
960   }
961 
962   segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end());
963 
964   return success();
965 }
966 
967 /// Parses a parallel operation.
968 ///
969 /// operation ::= `omp.parallel` clause-list
970 /// clause-list ::= clause | clause clause-list
971 /// clause ::= if | num-threads | private | firstprivate | shared | copyin |
972 ///            allocate | default | proc-bind
973 ///
974 static ParseResult parseParallelOp(OpAsmParser &parser,
975                                    OperationState &result) {
976   SmallVector<ClauseType> clauses = {
977       ifClause,           numThreadsClause, privateClause,
978       firstprivateClause, sharedClause,     copyinClause,
979       allocateClause,     defaultClause,    procBindClause};
980 
981   SmallVector<int> segments;
982 
983   if (failed(parseClauses(parser, result, clauses, segments)))
984     return failure();
985 
986   result.addAttribute("operand_segment_sizes",
987                       parser.getBuilder().getI32VectorAttr(segments));
988 
989   Region *body = result.addRegion();
990   SmallVector<OpAsmParser::OperandType> regionArgs;
991   SmallVector<Type> regionArgTypes;
992   if (parser.parseRegion(*body, regionArgs, regionArgTypes))
993     return failure();
994   return success();
995 }
996 
997 /// Parses a target operation.
998 ///
999 /// operation ::= `omp.target` clause-list
1000 /// clause-list ::= clause | clause clause-list
1001 /// clause ::= if | device | thread_limit | nowait
1002 ///
1003 static ParseResult parseTargetOp(OpAsmParser &parser, OperationState &result) {
1004   SmallVector<ClauseType> clauses = {ifClause, deviceClause, threadLimitClause,
1005                                      nowaitClause};
1006 
1007   SmallVector<int> segments;
1008 
1009   if (failed(parseClauses(parser, result, clauses, segments)))
1010     return failure();
1011 
1012   result.addAttribute(
1013       TargetOp::AttrSizedOperandSegments::getOperandSegmentSizeAttr(),
1014       parser.getBuilder().getI32VectorAttr(segments));
1015 
1016   Region *body = result.addRegion();
1017   SmallVector<OpAsmParser::OperandType> regionArgs;
1018   SmallVector<Type> regionArgTypes;
1019   if (parser.parseRegion(*body, regionArgs, regionArgTypes))
1020     return failure();
1021   return success();
1022 }
1023 
1024 //===----------------------------------------------------------------------===//
1025 // Parser, printer and verifier for SectionsOp
1026 //===----------------------------------------------------------------------===//
1027 
1028 /// Parses an OpenMP Sections operation
1029 ///
1030 /// sections ::= `omp.sections` clause-list
1031 /// clause-list ::= clause clause-list | empty
1032 /// clause ::= private | firstprivate | lastprivate | reduction | allocate |
1033 ///            nowait
1034 static ParseResult parseSectionsOp(OpAsmParser &parser,
1035                                    OperationState &result) {
1036 
1037   SmallVector<ClauseType> clauses = {privateClause,     firstprivateClause,
1038                                      lastprivateClause, reductionClause,
1039                                      allocateClause,    nowaitClause};
1040 
1041   SmallVector<int> segments;
1042 
1043   if (failed(parseClauses(parser, result, clauses, segments)))
1044     return failure();
1045 
1046   result.addAttribute("operand_segment_sizes",
1047                       parser.getBuilder().getI32VectorAttr(segments));
1048 
1049   // Now parse the body.
1050   Region *body = result.addRegion();
1051   if (parser.parseRegion(*body))
1052     return failure();
1053   return success();
1054 }
1055 
1056 static void printSectionsOp(OpAsmPrinter &p, SectionsOp op) {
1057   p << " ";
1058   printDataVars(p, op.private_vars(), "private");
1059   printDataVars(p, op.firstprivate_vars(), "firstprivate");
1060   printDataVars(p, op.lastprivate_vars(), "lastprivate");
1061 
1062   if (!op.reduction_vars().empty())
1063     printReductionVarList(p, op.reductions(), op.reduction_vars());
1064 
1065   if (!op.allocate_vars().empty())
1066     printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars());
1067 
1068   if (op.nowait())
1069     p << "nowait";
1070 
1071   p << ' ';
1072   p.printRegion(op.region());
1073 }
1074 
1075 static LogicalResult verifySectionsOp(SectionsOp op) {
1076 
1077   // A list item may not appear in more than one clause on the same directive,
1078   // except that it may be specified in both firstprivate and lastprivate
1079   // clauses.
1080   for (auto var : op.private_vars()) {
1081     if (llvm::is_contained(op.firstprivate_vars(), var))
1082       return op.emitOpError()
1083              << "operand used in both private and firstprivate clauses";
1084     if (llvm::is_contained(op.lastprivate_vars(), var))
1085       return op.emitOpError()
1086              << "operand used in both private and lastprivate clauses";
1087   }
1088 
1089   if (op.allocate_vars().size() != op.allocators_vars().size())
1090     return op.emitError(
1091         "expected equal sizes for allocate and allocator variables");
1092 
1093   for (auto &inst : *op.region().begin()) {
1094     if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst)))
1095       op.emitOpError()
1096           << "expected omp.section op or terminator op inside region";
1097   }
1098 
1099   return verifyReductionVarList(op, op.reductions(), op.reduction_vars());
1100 }
1101 
1102 /// Parses an OpenMP Workshare Loop operation
1103 ///
1104 /// wsloop ::= `omp.wsloop` loop-control clause-list
1105 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
1106 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
1107 /// steps := `step` `(`ssa-id-list`)`
1108 /// clause-list ::= clause clause-list | empty
1109 /// clause ::= private | firstprivate | lastprivate | linear | schedule |
1110 //             collapse | nowait | ordered | order | reduction
1111 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
1112 
1113   // Parse an opening `(` followed by induction variables followed by `)`
1114   SmallVector<OpAsmParser::OperandType> ivs;
1115   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
1116                                      OpAsmParser::Delimiter::Paren))
1117     return failure();
1118 
1119   int numIVs = static_cast<int>(ivs.size());
1120   Type loopVarType;
1121   if (parser.parseColonType(loopVarType))
1122     return failure();
1123 
1124   // Parse loop bounds.
1125   SmallVector<OpAsmParser::OperandType> lower;
1126   if (parser.parseEqual() ||
1127       parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
1128       parser.resolveOperands(lower, loopVarType, result.operands))
1129     return failure();
1130 
1131   SmallVector<OpAsmParser::OperandType> upper;
1132   if (parser.parseKeyword("to") ||
1133       parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
1134       parser.resolveOperands(upper, loopVarType, result.operands))
1135     return failure();
1136 
1137   if (succeeded(parser.parseOptionalKeyword("inclusive"))) {
1138     auto attr = UnitAttr::get(parser.getBuilder().getContext());
1139     result.addAttribute("inclusive", attr);
1140   }
1141 
1142   // Parse step values.
1143   SmallVector<OpAsmParser::OperandType> steps;
1144   if (parser.parseKeyword("step") ||
1145       parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
1146       parser.resolveOperands(steps, loopVarType, result.operands))
1147     return failure();
1148 
1149   SmallVector<ClauseType> clauses = {
1150       privateClause,   firstprivateClause, lastprivateClause, linearClause,
1151       reductionClause, collapseClause,     orderClause,       orderedClause,
1152       nowaitClause,    scheduleClause};
1153   SmallVector<int> segments{numIVs, numIVs, numIVs};
1154   if (failed(parseClauses(parser, result, clauses, segments)))
1155     return failure();
1156 
1157   result.addAttribute("operand_segment_sizes",
1158                       parser.getBuilder().getI32VectorAttr(segments));
1159 
1160   // Now parse the body.
1161   Region *body = result.addRegion();
1162   SmallVector<Type> ivTypes(numIVs, loopVarType);
1163   SmallVector<OpAsmParser::OperandType> blockArgs(ivs);
1164   if (parser.parseRegion(*body, blockArgs, ivTypes))
1165     return failure();
1166   return success();
1167 }
1168 
1169 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
1170   auto args = op.getRegion().front().getArguments();
1171   p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound()
1172     << ") to (" << op.upperBound() << ") ";
1173   if (op.inclusive()) {
1174     p << "inclusive ";
1175   }
1176   p << "step (" << op.step() << ") ";
1177 
1178   printDataVars(p, op.private_vars(), "private");
1179   printDataVars(p, op.firstprivate_vars(), "firstprivate");
1180   printDataVars(p, op.lastprivate_vars(), "lastprivate");
1181 
1182   if (!op.linear_vars().empty())
1183     printLinearClause(p, op.linear_vars(), op.linear_step_vars());
1184 
1185   if (auto sched = op.schedule_val())
1186     printScheduleClause(p, sched.getValue(), op.schedule_modifier(),
1187                         op.simd_modifier(), op.schedule_chunk_var());
1188 
1189   if (auto collapse = op.collapse_val())
1190     p << "collapse(" << collapse << ") ";
1191 
1192   if (op.nowait())
1193     p << "nowait ";
1194 
1195   if (auto ordered = op.ordered_val())
1196     p << "ordered(" << ordered << ") ";
1197 
1198   if (auto order = op.order_val())
1199     p << "order(" << stringifyClauseOrderKind(*order) << ") ";
1200 
1201   if (!op.reduction_vars().empty())
1202     printReductionVarList(p, op.reductions(), op.reduction_vars());
1203 
1204   p << ' ';
1205   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
1206 }
1207 
1208 //===----------------------------------------------------------------------===//
1209 // ReductionOp
1210 //===----------------------------------------------------------------------===//
1211 
1212 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
1213                                               Region &region) {
1214   if (parser.parseOptionalKeyword("atomic"))
1215     return success();
1216   return parser.parseRegion(region);
1217 }
1218 
1219 static void printAtomicReductionRegion(OpAsmPrinter &printer,
1220                                        ReductionDeclareOp op, Region &region) {
1221   if (region.empty())
1222     return;
1223   printer << "atomic ";
1224   printer.printRegion(region);
1225 }
1226 
1227 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) {
1228   if (op.initializerRegion().empty())
1229     return op.emitOpError() << "expects non-empty initializer region";
1230   Block &initializerEntryBlock = op.initializerRegion().front();
1231   if (initializerEntryBlock.getNumArguments() != 1 ||
1232       initializerEntryBlock.getArgument(0).getType() != op.type()) {
1233     return op.emitOpError() << "expects initializer region with one argument "
1234                                "of the reduction type";
1235   }
1236 
1237   for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) {
1238     if (yieldOp.results().size() != 1 ||
1239         yieldOp.results().getTypes()[0] != op.type())
1240       return op.emitOpError() << "expects initializer region to yield a value "
1241                                  "of the reduction type";
1242   }
1243 
1244   if (op.reductionRegion().empty())
1245     return op.emitOpError() << "expects non-empty reduction region";
1246   Block &reductionEntryBlock = op.reductionRegion().front();
1247   if (reductionEntryBlock.getNumArguments() != 2 ||
1248       reductionEntryBlock.getArgumentTypes()[0] !=
1249           reductionEntryBlock.getArgumentTypes()[1] ||
1250       reductionEntryBlock.getArgumentTypes()[0] != op.type())
1251     return op.emitOpError() << "expects reduction region with two arguments of "
1252                                "the reduction type";
1253   for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) {
1254     if (yieldOp.results().size() != 1 ||
1255         yieldOp.results().getTypes()[0] != op.type())
1256       return op.emitOpError() << "expects reduction region to yield a value "
1257                                  "of the reduction type";
1258   }
1259 
1260   if (op.atomicReductionRegion().empty())
1261     return success();
1262 
1263   Block &atomicReductionEntryBlock = op.atomicReductionRegion().front();
1264   if (atomicReductionEntryBlock.getNumArguments() != 2 ||
1265       atomicReductionEntryBlock.getArgumentTypes()[0] !=
1266           atomicReductionEntryBlock.getArgumentTypes()[1])
1267     return op.emitOpError() << "expects atomic reduction region with two "
1268                                "arguments of the same type";
1269   auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
1270                      .dyn_cast<PointerLikeType>();
1271   if (!ptrType || ptrType.getElementType() != op.type())
1272     return op.emitOpError() << "expects atomic reduction region arguments to "
1273                                "be accumulators containing the reduction type";
1274   return success();
1275 }
1276 
1277 static LogicalResult verifyReductionOp(ReductionOp op) {
1278   // TODO: generalize this to an op interface when there is more than one op
1279   // that supports reductions.
1280   auto container = op->getParentOfType<WsLoopOp>();
1281   for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
1282     if (container.reduction_vars()[i] == op.accumulator())
1283       return success();
1284 
1285   return op.emitOpError() << "the accumulator is not used by the parent";
1286 }
1287 
1288 //===----------------------------------------------------------------------===//
1289 // WsLoopOp
1290 //===----------------------------------------------------------------------===//
1291 
1292 void WsLoopOp::build(OpBuilder &builder, OperationState &state,
1293                      ValueRange lowerBound, ValueRange upperBound,
1294                      ValueRange step, ArrayRef<NamedAttribute> attributes) {
1295   build(builder, state, TypeRange(), lowerBound, upperBound, step,
1296         /*privateVars=*/ValueRange(),
1297         /*firstprivateVars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
1298         /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
1299         /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr,
1300         /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr,
1301         /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr,
1302         /*inclusive=*/nullptr, /*buildBody=*/false);
1303   state.addAttributes(attributes);
1304 }
1305 
1306 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes,
1307                      ValueRange operands, ArrayRef<NamedAttribute> attributes) {
1308   state.addOperands(operands);
1309   state.addAttributes(attributes);
1310   (void)state.addRegion();
1311   assert(resultTypes.empty() && "mismatched number of return types");
1312   state.addTypes(resultTypes);
1313 }
1314 
1315 void WsLoopOp::build(OpBuilder &builder, OperationState &result,
1316                      TypeRange typeRange, ValueRange lowerBounds,
1317                      ValueRange upperBounds, ValueRange steps,
1318                      ValueRange privateVars, ValueRange firstprivateVars,
1319                      ValueRange lastprivateVars, ValueRange linearVars,
1320                      ValueRange linearStepVars, ValueRange reductionVars,
1321                      StringAttr scheduleVal, Value scheduleChunkVar,
1322                      IntegerAttr collapseVal, UnitAttr nowait,
1323                      IntegerAttr orderedVal, StringAttr orderVal,
1324                      UnitAttr inclusive, bool buildBody) {
1325   result.addOperands(lowerBounds);
1326   result.addOperands(upperBounds);
1327   result.addOperands(steps);
1328   result.addOperands(privateVars);
1329   result.addOperands(firstprivateVars);
1330   result.addOperands(linearVars);
1331   result.addOperands(linearStepVars);
1332   if (scheduleChunkVar)
1333     result.addOperands(scheduleChunkVar);
1334 
1335   if (scheduleVal)
1336     result.addAttribute("schedule_val", scheduleVal);
1337   if (collapseVal)
1338     result.addAttribute("collapse_val", collapseVal);
1339   if (nowait)
1340     result.addAttribute("nowait", nowait);
1341   if (orderedVal)
1342     result.addAttribute("ordered_val", orderedVal);
1343   if (orderVal)
1344     result.addAttribute("order", orderVal);
1345   if (inclusive)
1346     result.addAttribute("inclusive", inclusive);
1347   result.addAttribute(
1348       WsLoopOp::getOperandSegmentSizeAttr(),
1349       builder.getI32VectorAttr(
1350           {static_cast<int32_t>(lowerBounds.size()),
1351            static_cast<int32_t>(upperBounds.size()),
1352            static_cast<int32_t>(steps.size()),
1353            static_cast<int32_t>(privateVars.size()),
1354            static_cast<int32_t>(firstprivateVars.size()),
1355            static_cast<int32_t>(lastprivateVars.size()),
1356            static_cast<int32_t>(linearVars.size()),
1357            static_cast<int32_t>(linearStepVars.size()),
1358            static_cast<int32_t>(reductionVars.size()),
1359            static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)}));
1360 
1361   Region *bodyRegion = result.addRegion();
1362   if (buildBody) {
1363     OpBuilder::InsertionGuard guard(builder);
1364     unsigned numIVs = steps.size();
1365     SmallVector<Type, 8> argTypes(numIVs, steps.getType().front());
1366     SmallVector<Location, 8> argLocs(numIVs, result.location);
1367     builder.createBlock(bodyRegion, {}, argTypes, argLocs);
1368   }
1369 }
1370 
1371 static LogicalResult verifyWsLoopOp(WsLoopOp op) {
1372   return verifyReductionVarList(op, op.reductions(), op.reduction_vars());
1373 }
1374 
1375 //===----------------------------------------------------------------------===//
1376 // Verifier for critical construct (2.17.1)
1377 //===----------------------------------------------------------------------===//
1378 
1379 static LogicalResult verifyCriticalDeclareOp(CriticalDeclareOp op) {
1380   return verifySynchronizationHint(op, op.hint());
1381 }
1382 
1383 static LogicalResult verifyCriticalOp(CriticalOp op) {
1384 
1385   if (op.nameAttr()) {
1386     auto symbolRef = op.nameAttr().cast<SymbolRefAttr>();
1387     auto decl =
1388         SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef);
1389     if (!decl) {
1390       return op.emitOpError() << "expected symbol reference " << symbolRef
1391                               << " to point to a critical declaration";
1392     }
1393   }
1394 
1395   return success();
1396 }
1397 
1398 //===----------------------------------------------------------------------===//
1399 // Verifier for ordered construct
1400 //===----------------------------------------------------------------------===//
1401 
1402 static LogicalResult verifyOrderedOp(OrderedOp op) {
1403   auto container = op->getParentOfType<WsLoopOp>();
1404   if (!container || !container.ordered_valAttr() ||
1405       container.ordered_valAttr().getInt() == 0)
1406     return op.emitOpError() << "ordered depend directive must be closely "
1407                             << "nested inside a worksharing-loop with ordered "
1408                             << "clause with parameter present";
1409 
1410   if (container.ordered_valAttr().getInt() !=
1411       (int64_t)op.num_loops_val().getValue())
1412     return op.emitOpError() << "number of variables in depend clause does not "
1413                             << "match number of iteration variables in the "
1414                             << "doacross loop";
1415 
1416   return success();
1417 }
1418 
1419 static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) {
1420   // TODO: The code generation for ordered simd directive is not supported yet.
1421   if (op.simd())
1422     return failure();
1423 
1424   if (auto container = op->getParentOfType<WsLoopOp>()) {
1425     if (!container.ordered_valAttr() ||
1426         container.ordered_valAttr().getInt() != 0)
1427       return op.emitOpError() << "ordered region must be closely nested inside "
1428                               << "a worksharing-loop region with an ordered "
1429                               << "clause without parameter present";
1430   }
1431 
1432   return success();
1433 }
1434 
1435 //===----------------------------------------------------------------------===//
1436 // AtomicReadOp
1437 //===----------------------------------------------------------------------===//
1438 
1439 /// Parser for AtomicReadOp
1440 ///
1441 /// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type
1442 /// address ::= operand `:` type
1443 static ParseResult parseAtomicReadOp(OpAsmParser &parser,
1444                                      OperationState &result) {
1445   OpAsmParser::OperandType x, v;
1446   Type addressType;
1447   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1448   SmallVector<int> segments;
1449 
1450   if (parser.parseOperand(v) || parser.parseEqual() || parser.parseOperand(x) ||
1451       parseClauses(parser, result, clauses, segments) ||
1452       parser.parseColonType(addressType) ||
1453       parser.resolveOperand(x, addressType, result.operands) ||
1454       parser.resolveOperand(v, addressType, result.operands))
1455     return failure();
1456 
1457   return success();
1458 }
1459 
1460 /// Printer for AtomicReadOp
1461 static void printAtomicReadOp(OpAsmPrinter &p, AtomicReadOp op) {
1462   p << " " << op.v() << " = " << op.x() << " ";
1463   if (auto mo = op.memory_order())
1464     p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
1465   if (op.hintAttr())
1466     printSynchronizationHint(p << " ", op, op.hintAttr());
1467   p << ": " << op.x().getType();
1468 }
1469 
1470 /// Verifier for AtomicReadOp
1471 static LogicalResult verifyAtomicReadOp(AtomicReadOp op) {
1472   if (auto mo = op.memory_order()) {
1473     if (*mo == ClauseMemoryOrderKind::acq_rel ||
1474         *mo == ClauseMemoryOrderKind::release) {
1475       return op.emitError(
1476           "memory-order must not be acq_rel or release for atomic reads");
1477     }
1478   }
1479   if (op.x() == op.v())
1480     return op.emitError(
1481         "read and write must not be to the same location for atomic reads");
1482   return verifySynchronizationHint(op, op.hint());
1483 }
1484 
1485 //===----------------------------------------------------------------------===//
1486 // AtomicWriteOp
1487 //===----------------------------------------------------------------------===//
1488 
1489 /// Parser for AtomicWriteOp
1490 ///
1491 /// operation ::= `omp.atomic.write` atomic-clause-list operands
1492 /// operands ::= address `,` value
1493 /// address ::= operand `:` type
1494 /// value ::= operand `:` type
1495 static ParseResult parseAtomicWriteOp(OpAsmParser &parser,
1496                                       OperationState &result) {
1497   OpAsmParser::OperandType address, value;
1498   Type addrType, valueType;
1499   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1500   SmallVector<int> segments;
1501 
1502   if (parser.parseOperand(address) || parser.parseEqual() ||
1503       parser.parseOperand(value) ||
1504       parseClauses(parser, result, clauses, segments) ||
1505       parser.parseColonType(addrType) || parser.parseComma() ||
1506       parser.parseType(valueType) ||
1507       parser.resolveOperand(address, addrType, result.operands) ||
1508       parser.resolveOperand(value, valueType, result.operands))
1509     return failure();
1510   return success();
1511 }
1512 
1513 /// Printer for AtomicWriteOp
1514 static void printAtomicWriteOp(OpAsmPrinter &p, AtomicWriteOp op) {
1515   p << " " << op.address() << " = " << op.value() << " ";
1516   if (auto mo = op.memory_order())
1517     p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
1518   if (op.hintAttr())
1519     printSynchronizationHint(p, op, op.hintAttr());
1520   p << ": " << op.address().getType() << ", " << op.value().getType();
1521 }
1522 
1523 /// Verifier for AtomicWriteOp
1524 static LogicalResult verifyAtomicWriteOp(AtomicWriteOp op) {
1525   if (auto mo = op.memory_order()) {
1526     if (*mo == ClauseMemoryOrderKind::acq_rel ||
1527         *mo == ClauseMemoryOrderKind::acquire) {
1528       return op.emitError(
1529           "memory-order must not be acq_rel or acquire for atomic writes");
1530     }
1531   }
1532   return verifySynchronizationHint(op, op.hint());
1533 }
1534 
1535 //===----------------------------------------------------------------------===//
1536 // AtomicUpdateOp
1537 //===----------------------------------------------------------------------===//
1538 
1539 /// Parser for AtomicUpdateOp
1540 ///
1541 /// operation ::= `omp.atomic.update` atomic-clause-list region
1542 static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
1543                                        OperationState &result) {
1544   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1545   SmallVector<int> segments;
1546   OpAsmParser::OperandType x, y, z;
1547   Type xType, exprType;
1548   StringRef binOp;
1549 
1550   // x = y `op` z : xtype, exprtype
1551   if (parser.parseOperand(x) || parser.parseEqual() || parser.parseOperand(y) ||
1552       parser.parseKeyword(&binOp) || parser.parseOperand(z) ||
1553       parseClauses(parser, result, clauses, segments) || parser.parseColon() ||
1554       parser.parseType(xType) || parser.parseComma() ||
1555       parser.parseType(exprType) ||
1556       parser.resolveOperand(x, xType, result.operands)) {
1557     return failure();
1558   }
1559 
1560   auto binOpEnum = AtomicBinOpKindToEnum(binOp.upper());
1561   if (!binOpEnum)
1562     return parser.emitError(parser.getNameLoc())
1563            << "invalid atomic bin op in atomic update\n";
1564   auto attr =
1565       parser.getBuilder().getI64IntegerAttr((int64_t)binOpEnum.getValue());
1566   result.addAttribute("binop", attr);
1567 
1568   OpAsmParser::OperandType expr;
1569   if (x.name == y.name && x.number == y.number) {
1570     expr = z;
1571     result.addAttribute("isXBinopExpr", parser.getBuilder().getUnitAttr());
1572   } else if (x.name == z.name && x.number == z.number) {
1573     expr = y;
1574   } else {
1575     return parser.emitError(parser.getNameLoc())
1576            << "atomic update variable " << x.name
1577            << " not found in the RHS of the assignment statement in an"
1578               " atomic.update operation";
1579   }
1580   return parser.resolveOperand(expr, exprType, result.operands);
1581 }
1582 
1583 /// Printer for AtomicUpdateOp
1584 static void printAtomicUpdateOp(OpAsmPrinter &p, AtomicUpdateOp op) {
1585   p << " " << op.x() << " = ";
1586   Value y, z;
1587   if (op.isXBinopExpr()) {
1588     y = op.x();
1589     z = op.expr();
1590   } else {
1591     y = op.expr();
1592     z = op.x();
1593   }
1594   p << y << " " << AtomicBinOpKindToString(op.binop()).lower() << " " << z
1595     << " ";
1596   if (auto mo = op.memory_order())
1597     p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
1598   if (op.hintAttr())
1599     printSynchronizationHint(p, op, op.hintAttr());
1600   p << ": " << op.x().getType() << ", " << op.expr().getType();
1601 }
1602 
1603 /// Verifier for AtomicUpdateOp
1604 static LogicalResult verifyAtomicUpdateOp(AtomicUpdateOp op) {
1605   if (auto mo = op.memory_order()) {
1606     if (*mo == ClauseMemoryOrderKind::acq_rel ||
1607         *mo == ClauseMemoryOrderKind::acquire) {
1608       return op.emitError(
1609           "memory-order must not be acq_rel or acquire for atomic updates");
1610     }
1611   }
1612   return success();
1613 }
1614 
1615 //===----------------------------------------------------------------------===//
1616 // AtomicCaptureOp
1617 //===----------------------------------------------------------------------===//
1618 
1619 /// Parser for AtomicCaptureOp
1620 static LogicalResult parseAtomicCaptureOp(OpAsmParser &parser,
1621                                           OperationState &result) {
1622   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1623   SmallVector<int> segments;
1624   if (parseClauses(parser, result, clauses, segments) ||
1625       parser.parseRegion(*result.addRegion()))
1626     return failure();
1627   return success();
1628 }
1629 
1630 /// Printer for AtomicCaptureOp
1631 static void printAtomicCaptureOp(OpAsmPrinter &p, AtomicCaptureOp op) {
1632   if (op.memory_order())
1633     p << "memory_order(" << op.memory_order() << ") ";
1634   if (op.hintAttr())
1635     printSynchronizationHint(p, op, op.hintAttr());
1636   p.printRegion(op.region());
1637 }
1638 
1639 /// Verifier for AtomicCaptureOp
1640 static LogicalResult verifyAtomicCaptureOp(AtomicCaptureOp op) {
1641   Block::OpListType &ops = op.region().front().getOperations();
1642   if (ops.size() != 3)
1643     return emitError(op.getLoc())
1644            << "expected three operations in omp.atomic.capture region (one "
1645               "terminator, and two atomic ops)";
1646   auto &firstOp = ops.front();
1647   auto &secondOp = *ops.getNextNode(firstOp);
1648   auto firstReadStmt = dyn_cast<AtomicReadOp>(firstOp);
1649   auto firstUpdateStmt = dyn_cast<AtomicUpdateOp>(firstOp);
1650   auto secondReadStmt = dyn_cast<AtomicReadOp>(secondOp);
1651   auto secondUpdateStmt = dyn_cast<AtomicUpdateOp>(secondOp);
1652   auto secondWriteStmt = dyn_cast<AtomicWriteOp>(secondOp);
1653 
1654   if (!((firstUpdateStmt && secondReadStmt) ||
1655         (firstReadStmt && secondUpdateStmt) ||
1656         (firstReadStmt && secondWriteStmt)))
1657     return emitError(ops.front().getLoc())
1658            << "invalid sequence of operations in the capture region";
1659   if (firstUpdateStmt && secondReadStmt &&
1660       firstUpdateStmt.x() != secondReadStmt.x())
1661     return emitError(firstUpdateStmt.getLoc())
1662            << "updated variable in omp.atomic.update must be captured in "
1663               "second operation";
1664   if (firstReadStmt && secondUpdateStmt &&
1665       firstReadStmt.x() != secondUpdateStmt.x())
1666     return emitError(firstReadStmt.getLoc())
1667            << "captured variable in omp.atomic.read must be updated in second "
1668               "operation";
1669   if (firstReadStmt && secondWriteStmt &&
1670       firstReadStmt.x() != secondWriteStmt.address())
1671     return emitError(firstReadStmt.getLoc())
1672            << "captured variable in omp.atomic.read must be updated in "
1673               "second operation";
1674   return success();
1675 }
1676 
1677 #define GET_ATTRDEF_CLASSES
1678 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
1679 
1680 #define GET_OP_CLASSES
1681 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
1682