1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/OpImplementation.h"
19 #include "mlir/IR/OperationSupport.h"
20 
21 #include "llvm/ADT/BitVector.h"
22 #include "llvm/ADT/SmallString.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/ADT/StringSwitch.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include <cstddef>
28 
29 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
30 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
31 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
32 
33 using namespace mlir;
34 using namespace mlir::omp;
35 
36 namespace {
37 /// Model for pointer-like types that already provide a `getElementType` method.
38 template <typename T>
39 struct PointerLikeModel
40     : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
41   Type getElementType(Type pointer) const {
42     return pointer.cast<T>().getElementType();
43   }
44 };
45 } // namespace
46 
/// Registers the dialect's operations and attributes (the lists come from the
/// generated `.cpp.inc` files) and attaches the `PointerLikeType` interface to
/// the LLVM dialect pointer type and the builtin memref type.
void OpenMPDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
      >();

  // Both types already provide getElementType(); PointerLikeModel simply
  // forwards the interface method to that accessor.
  LLVM::LLVMPointerType::attachInterface<
      PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
  MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
}
61 
62 //===----------------------------------------------------------------------===//
63 // ParallelOp
64 //===----------------------------------------------------------------------===//
65 
/// Builds a ParallelOp with no clauses: every optional operand and clause
/// attribute is passed as null or as an empty range to the full builder, and
/// the caller-supplied `attributes` are then added to the operation state.
void ParallelOp::build(OpBuilder &builder, OperationState &state,
                       ArrayRef<NamedAttribute> attributes) {
  ParallelOp::build(
      builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
      /*default_val=*/nullptr, /*private_vars=*/ValueRange(),
      /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(),
      /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(),
      /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr);
  state.addAttributes(attributes);
}
76 
77 //===----------------------------------------------------------------------===//
78 // Parser and printer for Operand and type list
79 //===----------------------------------------------------------------------===//
80 
81 /// Parse a list of operands with types.
82 ///
83 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
84 /// ssa-id-and-type-list ::= ssa-id-and-type |
85 ///                          ssa-id-and-type `,` ssa-id-and-type-list
86 /// ssa-id-and-type ::= ssa-id `:` type
87 static ParseResult
88 parseOperandAndTypeList(OpAsmParser &parser,
89                         SmallVectorImpl<OpAsmParser::OperandType> &operands,
90                         SmallVectorImpl<Type> &types) {
91   return parser.parseCommaSeparatedList(
92       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
93         OpAsmParser::OperandType operand;
94         Type type;
95         if (parser.parseOperand(operand) || parser.parseColonType(type))
96           return failure();
97         operands.push_back(operand);
98         types.push_back(type);
99         return success();
100       });
101 }
102 
103 /// Print an operand and type list with parentheses
104 static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) {
105   p << "(";
106   llvm::interleaveComma(
107       operands, p, [&](const Value &v) { p << v << " : " << v.getType(); });
108   p << ") ";
109 }
110 
111 /// Print data variables corresponding to a data-sharing clause `name`
112 static void printDataVars(OpAsmPrinter &p, OperandRange operands,
113                           StringRef name) {
114   if (!operands.empty()) {
115     p << name;
116     printOperandAndTypeList(p, operands);
117   }
118 }
119 
120 //===----------------------------------------------------------------------===//
121 // Parser and printer for Allocate Clause
122 //===----------------------------------------------------------------------===//
123 
124 /// Parse an allocate clause with allocators and a list of operands with types.
125 ///
126 /// allocate ::= `allocate` `(` allocate-operand-list `)`
127 /// allocate-operand-list :: = allocate-operand |
128 ///                            allocator-operand `,` allocate-operand-list
129 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
130 /// ssa-id-and-type ::= ssa-id `:` type
131 static ParseResult parseAllocateAndAllocator(
132     OpAsmParser &parser,
133     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
134     SmallVectorImpl<Type> &typesAllocate,
135     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
136     SmallVectorImpl<Type> &typesAllocator) {
137 
138   return parser.parseCommaSeparatedList(
139       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
140         OpAsmParser::OperandType operand;
141         Type type;
142         if (parser.parseOperand(operand) || parser.parseColonType(type))
143           return failure();
144         operandsAllocator.push_back(operand);
145         typesAllocator.push_back(type);
146         if (parser.parseArrow())
147           return failure();
148         if (parser.parseOperand(operand) || parser.parseColonType(type))
149           return failure();
150 
151         operandsAllocate.push_back(operand);
152         typesAllocate.push_back(type);
153         return success();
154       });
155 }
156 
157 /// Print allocate clause
158 static void printAllocateAndAllocator(OpAsmPrinter &p,
159                                       OperandRange varsAllocate,
160                                       OperandRange varsAllocator) {
161   p << "allocate(";
162   for (unsigned i = 0; i < varsAllocate.size(); ++i) {
163     std::string separator = i == varsAllocate.size() - 1 ? ") " : ", ";
164     p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
165     p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
166   }
167 }
168 
169 static LogicalResult verifyParallelOp(ParallelOp op) {
170   if (op.allocate_vars().size() != op.allocators_vars().size())
171     return op.emitError(
172         "expected equal sizes for allocate and allocator variables");
173   return success();
174 }
175 
176 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
177   p << " ";
178   if (auto ifCond = op.if_expr_var())
179     p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
180 
181   if (auto threads = op.num_threads_var())
182     p << "num_threads(" << threads << " : " << threads.getType() << ") ";
183 
184   printDataVars(p, op.private_vars(), "private");
185   printDataVars(p, op.firstprivate_vars(), "firstprivate");
186   printDataVars(p, op.shared_vars(), "shared");
187   printDataVars(p, op.copyin_vars(), "copyin");
188 
189   if (!op.allocate_vars().empty())
190     printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars());
191 
192   if (auto def = op.default_val())
193     p << "default(" << stringifyClauseDefault(*def).drop_front(3) << ") ";
194 
195   if (auto bind = op.proc_bind_val())
196     p << "proc_bind(" << stringifyClauseProcBindKind(*bind) << ") ";
197 
198   p << ' ';
199   p.printRegion(op.getRegion());
200 }
201 
202 //===----------------------------------------------------------------------===//
203 // Parser and printer for Linear Clause
204 //===----------------------------------------------------------------------===//
205 
206 /// linear ::= `linear` `(` linear-list `)`
207 /// linear-list := linear-val | linear-val linear-list
208 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
209 static ParseResult
210 parseLinearClause(OpAsmParser &parser,
211                   SmallVectorImpl<OpAsmParser::OperandType> &vars,
212                   SmallVectorImpl<Type> &types,
213                   SmallVectorImpl<OpAsmParser::OperandType> &stepVars) {
214   if (parser.parseLParen())
215     return failure();
216 
217   do {
218     OpAsmParser::OperandType var;
219     Type type;
220     OpAsmParser::OperandType stepVar;
221     if (parser.parseOperand(var) || parser.parseEqual() ||
222         parser.parseOperand(stepVar) || parser.parseColonType(type))
223       return failure();
224 
225     vars.push_back(var);
226     types.push_back(type);
227     stepVars.push_back(stepVar);
228   } while (succeeded(parser.parseOptionalComma()));
229 
230   if (parser.parseRParen())
231     return failure();
232 
233   return success();
234 }
235 
236 /// Print Linear Clause
237 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars,
238                               OperandRange linearStepVars) {
239   size_t linearVarsSize = linearVars.size();
240   p << "linear(";
241   for (unsigned i = 0; i < linearVarsSize; ++i) {
242     std::string separator = i == linearVarsSize - 1 ? ") " : ", ";
243     p << linearVars[i];
244     if (linearStepVars.size() > i)
245       p << " = " << linearStepVars[i];
246     p << " : " << linearVars[i].getType() << separator;
247   }
248 }
249 
250 //===----------------------------------------------------------------------===//
251 // Parser and printer for Schedule Clause
252 //===----------------------------------------------------------------------===//
253 
254 static ParseResult
255 verifyScheduleModifiers(OpAsmParser &parser,
256                         SmallVectorImpl<SmallString<12>> &modifiers) {
257   if (modifiers.size() > 2)
258     return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
259   for (const auto &mod : modifiers) {
260     // Translate the string. If it has no value, then it was not a valid
261     // modifier!
262     auto symbol = symbolizeScheduleModifier(mod);
263     if (!symbol.hasValue())
264       return parser.emitError(parser.getNameLoc())
265              << " unknown modifier type: " << mod;
266   }
267 
268   // If we have one modifier that is "simd", then stick a "none" modiifer in
269   // index 0.
270   if (modifiers.size() == 1) {
271     if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
272       modifiers.push_back(modifiers[0]);
273       modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
274     }
275   } else if (modifiers.size() == 2) {
276     // If there are two modifier:
277     // First modifier should not be simd, second one should be simd
278     if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
279         symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
280       return parser.emitError(parser.getNameLoc())
281              << " incorrect modifier order";
282   }
283   return success();
284 }
285 
286 /// schedule ::= `schedule` `(` sched-list `)`
287 /// sched-list ::= sched-val | sched-val sched-list |
288 ///                sched-val `,` sched-modifier
289 /// sched-val ::= sched-with-chunk | sched-wo-chunk
290 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
291 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
292 /// sched-wo-chunk ::=  `auto` | `runtime`
293 /// sched-modifier ::=  sched-mod-val | sched-mod-val `,` sched-mod-val
294 /// sched-mod-val ::=  `monotonic` | `nonmonotonic` | `simd` | `none`
295 static ParseResult
296 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
297                     SmallVectorImpl<SmallString<12>> &modifiers,
298                     Optional<OpAsmParser::OperandType> &chunkSize,
299                     Type &chunkType) {
300   if (parser.parseLParen())
301     return failure();
302 
303   StringRef keyword;
304   if (parser.parseKeyword(&keyword))
305     return failure();
306 
307   schedule = keyword;
308   if (keyword == "static" || keyword == "dynamic" || keyword == "guided") {
309     if (succeeded(parser.parseOptionalEqual())) {
310       chunkSize = OpAsmParser::OperandType{};
311       if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
312         return failure();
313     } else {
314       chunkSize = llvm::NoneType::None;
315     }
316   } else if (keyword == "auto" || keyword == "runtime") {
317     chunkSize = llvm::NoneType::None;
318   } else {
319     return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
320   }
321 
322   // If there is a comma, we have one or more modifiers..
323   while (succeeded(parser.parseOptionalComma())) {
324     StringRef mod;
325     if (parser.parseKeyword(&mod))
326       return failure();
327     modifiers.push_back(mod);
328   }
329 
330   if (parser.parseRParen())
331     return failure();
332 
333   if (verifyScheduleModifiers(parser, modifiers))
334     return failure();
335 
336   return success();
337 }
338 
339 /// Print schedule clause
340 static void printScheduleClause(OpAsmPrinter &p, ClauseScheduleKind sched,
341                                 Optional<ScheduleModifier> modifier, bool simd,
342                                 Value scheduleChunkVar) {
343   p << "schedule(" << stringifyClauseScheduleKind(sched).lower();
344   if (scheduleChunkVar)
345     p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType();
346   if (modifier)
347     p << ", " << stringifyScheduleModifier(*modifier);
348   if (simd)
349     p << ", simd";
350   p << ") ";
351 }
352 
353 //===----------------------------------------------------------------------===//
354 // Parser, printer and verifier for ReductionVarList
355 //===----------------------------------------------------------------------===//
356 
357 /// reduction ::= `reduction` `(` reduction-entry-list `)`
358 /// reduction-entry-list ::= reduction-entry
359 ///                        | reduction-entry-list `,` reduction-entry
360 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
361 static ParseResult
362 parseReductionVarList(OpAsmParser &parser,
363                       SmallVectorImpl<SymbolRefAttr> &symbols,
364                       SmallVectorImpl<OpAsmParser::OperandType> &operands,
365                       SmallVectorImpl<Type> &types) {
366   if (failed(parser.parseLParen()))
367     return failure();
368 
369   do {
370     if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() ||
371         parser.parseOperand(operands.emplace_back()) ||
372         parser.parseColonType(types.emplace_back()))
373       return failure();
374   } while (succeeded(parser.parseOptionalComma()));
375   return parser.parseRParen();
376 }
377 
378 /// Print Reduction clause
379 static void printReductionVarList(OpAsmPrinter &p,
380                                   Optional<ArrayAttr> reductions,
381                                   OperandRange reductionVars) {
382   p << "reduction(";
383   for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
384     if (i != 0)
385       p << ", ";
386     p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
387       << reductionVars[i].getType();
388   }
389   p << ") ";
390 }
391 
392 /// Verifies Reduction Clause
393 static LogicalResult verifyReductionVarList(Operation *op,
394                                             Optional<ArrayAttr> reductions,
395                                             OperandRange reductionVars) {
396   if (!reductionVars.empty()) {
397     if (!reductions || reductions->size() != reductionVars.size())
398       return op->emitOpError()
399              << "expected as many reduction symbol references "
400                 "as reduction variables";
401   } else {
402     if (reductions)
403       return op->emitOpError() << "unexpected reduction symbol references";
404     return success();
405   }
406 
407   DenseSet<Value> accumulators;
408   for (auto args : llvm::zip(reductionVars, *reductions)) {
409     Value accum = std::get<0>(args);
410 
411     if (!accumulators.insert(accum).second)
412       return op->emitOpError() << "accumulator variable used more than once";
413 
414     Type varType = accum.getType().cast<PointerLikeType>();
415     auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
416     auto decl =
417         SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
418     if (!decl)
419       return op->emitOpError() << "expected symbol reference " << symbolRef
420                                << " to point to a reduction declaration";
421 
422     if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
423       return op->emitOpError()
424              << "expected accumulator (" << varType
425              << ") to be the same type as reduction declaration ("
426              << decl.getAccumulatorType() << ")";
427   }
428 
429   return success();
430 }
431 
432 //===----------------------------------------------------------------------===//
433 // Parser, printer and verifier for Synchronization Hint (2.17.12)
434 //===----------------------------------------------------------------------===//
435 
436 /// Parses a Synchronization Hint clause. The value of hint is an integer
437 /// which is a combination of different hints from `omp_sync_hint_t`.
438 ///
439 /// hint-clause = `hint` `(` hint-value `)`
440 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
441                                             IntegerAttr &hintAttr,
442                                             bool parseKeyword = true) {
443   if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) {
444     hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
445     return success();
446   }
447 
448   if (failed(parser.parseLParen()))
449     return failure();
450   StringRef hintKeyword;
451   int64_t hint = 0;
452   do {
453     if (failed(parser.parseKeyword(&hintKeyword)))
454       return failure();
455     if (hintKeyword == "uncontended")
456       hint |= 1;
457     else if (hintKeyword == "contended")
458       hint |= 2;
459     else if (hintKeyword == "nonspeculative")
460       hint |= 4;
461     else if (hintKeyword == "speculative")
462       hint |= 8;
463     else
464       return parser.emitError(parser.getCurrentLocation())
465              << hintKeyword << " is not a valid hint";
466   } while (succeeded(parser.parseOptionalComma()));
467   if (failed(parser.parseRParen()))
468     return failure();
469   hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
470   return success();
471 }
472 
473 /// Prints a Synchronization Hint clause
474 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
475                                      IntegerAttr hintAttr) {
476   int64_t hint = hintAttr.getInt();
477 
478   if (hint == 0)
479     return;
480 
481   // Helper function to get n-th bit from the right end of `value`
482   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
483 
484   bool uncontended = bitn(hint, 0);
485   bool contended = bitn(hint, 1);
486   bool nonspeculative = bitn(hint, 2);
487   bool speculative = bitn(hint, 3);
488 
489   SmallVector<StringRef> hints;
490   if (uncontended)
491     hints.push_back("uncontended");
492   if (contended)
493     hints.push_back("contended");
494   if (nonspeculative)
495     hints.push_back("nonspeculative");
496   if (speculative)
497     hints.push_back("speculative");
498 
499   p << "hint(";
500   llvm::interleaveComma(hints, p);
501   p << ") ";
502 }
503 
504 /// Verifies a synchronization hint clause
505 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
506 
507   // Helper function to get n-th bit from the right end of `value`
508   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
509 
510   bool uncontended = bitn(hint, 0);
511   bool contended = bitn(hint, 1);
512   bool nonspeculative = bitn(hint, 2);
513   bool speculative = bitn(hint, 3);
514 
515   if (uncontended && contended)
516     return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
517                                 "omp_sync_hint_contended cannot be combined";
518   if (nonspeculative && speculative)
519     return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
520                                 "omp_sync_hint_speculative cannot be combined.";
521   return success();
522 }
523 
/// Identifies each clause that parseClauses() can parse. The enumerators are
/// used directly as indices into per-clause tables sized by `COUNT` (the
/// `done` bit vector and the `pos` vector in parseClauses()). New clauses must
/// therefore be added before `COUNT`, which has to remain the last enumerator.
enum ClauseType {
  ifClause,
  numThreadsClause,
  privateClause,
  firstprivateClause,
  lastprivateClause,
  sharedClause,
  copyinClause,
  allocateClause,
  defaultClause,
  procBindClause,
  reductionClause,
  nowaitClause,
  linearClause,
  scheduleClause,
  collapseClause,
  orderClause,
  orderedClause,
  memoryOrderClause,
  hintClause,
  // Number of clause kinds; not a real clause.
  COUNT
};
546 
547 //===----------------------------------------------------------------------===//
548 // Parser for Clause List
549 //===----------------------------------------------------------------------===//
550 
551 /// Parse a clause attribute `(` $value `)`.
552 template <typename ClauseAttr>
553 static ParseResult parseClauseAttr(AsmParser &parser, OperationState &state,
554                                    StringRef attrName, StringRef name) {
555   using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
556   StringRef enumStr;
557   llvm::SMLoc loc = parser.getCurrentLocation();
558   if (parser.parseLParen() || parser.parseKeyword(&enumStr) ||
559       parser.parseRParen())
560     return failure();
561   if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
562     auto attr = ClauseAttr::get(parser.getContext(), *enumValue);
563     state.addAttribute(attrName, attr);
564     return success();
565   }
566   return parser.emitError(loc, "invalid ") << name << " kind";
567 }
568 
/// Parse a list of clauses. The clauses can appear in any order, but their
/// operand segment indices are in the same order that they are passed in the
/// `clauses` list. The operand segment sizes of these clauses are added after
/// any segments already present in `segments`.
///
/// clause-list ::= clause clause-list | empty
/// clause ::= if | num-threads | private | firstprivate | lastprivate |
///            shared | copyin | allocate | default | proc-bind | reduction |
///            nowait | linear | schedule | collapse | order | ordered |
///            inclusive | memory-order | hint
578 /// if ::= `if` `(` ssa-id-and-type `)`
579 /// num-threads ::= `num_threads` `(` ssa-id-and-type `)`
580 /// private ::= `private` operand-and-type-list
581 /// firstprivate ::= `firstprivate` operand-and-type-list
582 /// lastprivate ::= `lastprivate` operand-and-type-list
583 /// shared ::= `shared` operand-and-type-list
584 /// copyin ::= `copyin` operand-and-type-list
585 /// allocate ::= `allocate` `(` allocate-operand-list `)`
/// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) `)`
587 /// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
588 /// reduction ::= `reduction` `(` reduction-entry-list `)`
589 /// nowait ::= `nowait`
590 /// linear ::= `linear` `(` linear-list `)`
591 /// schedule ::= `schedule` `(` sched-list `)`
/// collapse ::= `collapse` `(` integer-literal `)`
/// order ::= `order` `(` `concurrent` `)`
/// ordered ::= `ordered` (`(` integer-literal `)`)?
/// inclusive ::= `inclusive`
/// memory-order ::= `memory_order` `(` memory-order-kind `)`
/// hint ::= `hint` `(` hint-value-list `)`
///
/// Note that each clause can only appear once in the clause-list.
598 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
599                                 SmallVectorImpl<ClauseType> &clauses,
600                                 SmallVectorImpl<int> &segments) {
601 
602   // Check done[clause] to see if it has been parsed already
603   llvm::BitVector done(ClauseType::COUNT, false);
604 
605   // See pos[clause] to get position of clause in operand segments
606   SmallVector<int> pos(ClauseType::COUNT, -1);
607 
608   // Stores the last parsed clause keyword
609   StringRef clauseKeyword;
610   StringRef opName = result.name.getStringRef();
611 
612   // Containers for storing operands, types and attributes for various clauses
613   std::pair<OpAsmParser::OperandType, Type> ifCond;
614   std::pair<OpAsmParser::OperandType, Type> numThreads;
615 
616   SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates,
617       shareds, copyins;
618   SmallVector<Type> privateTypes, firstprivateTypes, lastprivateTypes,
619       sharedTypes, copyinTypes;
620 
621   SmallVector<OpAsmParser::OperandType> allocates, allocators;
622   SmallVector<Type> allocateTypes, allocatorTypes;
623 
624   SmallVector<SymbolRefAttr> reductionSymbols;
625   SmallVector<OpAsmParser::OperandType> reductionVars;
626   SmallVector<Type> reductionVarTypes;
627 
628   SmallVector<OpAsmParser::OperandType> linears;
629   SmallVector<Type> linearTypes;
630   SmallVector<OpAsmParser::OperandType> linearSteps;
631 
632   SmallString<8> schedule;
633   SmallVector<SmallString<12>> modifiers;
634   Optional<OpAsmParser::OperandType> scheduleChunkSize;
635   Type scheduleChunkType;
636 
637   // Compute the position of clauses in operand segments
638   int currPos = 0;
639   for (ClauseType clause : clauses) {
640 
641     // Skip the following clauses - they do not take any position in operand
642     // segments
643     if (clause == defaultClause || clause == procBindClause ||
644         clause == nowaitClause || clause == collapseClause ||
645         clause == orderClause || clause == orderedClause)
646       continue;
647 
648     pos[clause] = currPos++;
649 
650     // For the following clauses, two positions are reserved in the operand
651     // segments
652     if (clause == allocateClause || clause == linearClause)
653       currPos++;
654   }
655 
656   SmallVector<int> clauseSegments(currPos);
657 
658   // Helper function to check if a clause is allowed/repeated or not
659   auto checkAllowed = [&](ClauseType clause) -> ParseResult {
660     if (!llvm::is_contained(clauses, clause))
661       return parser.emitError(parser.getCurrentLocation())
662              << clauseKeyword << " is not a valid clause for the " << opName
663              << " operation";
664     if (done[clause])
665       return parser.emitError(parser.getCurrentLocation())
666              << "at most one " << clauseKeyword << " clause can appear on the "
667              << opName << " operation";
668     done[clause] = true;
669     return success();
670   };
671 
672   while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) {
673     if (clauseKeyword == "if") {
674       if (checkAllowed(ifClause) || parser.parseLParen() ||
675           parser.parseOperand(ifCond.first) ||
676           parser.parseColonType(ifCond.second) || parser.parseRParen())
677         return failure();
678       clauseSegments[pos[ifClause]] = 1;
679     } else if (clauseKeyword == "num_threads") {
680       if (checkAllowed(numThreadsClause) || parser.parseLParen() ||
681           parser.parseOperand(numThreads.first) ||
682           parser.parseColonType(numThreads.second) || parser.parseRParen())
683         return failure();
684       clauseSegments[pos[numThreadsClause]] = 1;
685     } else if (clauseKeyword == "private") {
686       if (checkAllowed(privateClause) ||
687           parseOperandAndTypeList(parser, privates, privateTypes))
688         return failure();
689       clauseSegments[pos[privateClause]] = privates.size();
690     } else if (clauseKeyword == "firstprivate") {
691       if (checkAllowed(firstprivateClause) ||
692           parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
693         return failure();
694       clauseSegments[pos[firstprivateClause]] = firstprivates.size();
695     } else if (clauseKeyword == "lastprivate") {
696       if (checkAllowed(lastprivateClause) ||
697           parseOperandAndTypeList(parser, lastprivates, lastprivateTypes))
698         return failure();
699       clauseSegments[pos[lastprivateClause]] = lastprivates.size();
700     } else if (clauseKeyword == "shared") {
701       if (checkAllowed(sharedClause) ||
702           parseOperandAndTypeList(parser, shareds, sharedTypes))
703         return failure();
704       clauseSegments[pos[sharedClause]] = shareds.size();
705     } else if (clauseKeyword == "copyin") {
706       if (checkAllowed(copyinClause) ||
707           parseOperandAndTypeList(parser, copyins, copyinTypes))
708         return failure();
709       clauseSegments[pos[copyinClause]] = copyins.size();
710     } else if (clauseKeyword == "allocate") {
711       if (checkAllowed(allocateClause) ||
712           parseAllocateAndAllocator(parser, allocates, allocateTypes,
713                                     allocators, allocatorTypes))
714         return failure();
715       clauseSegments[pos[allocateClause]] = allocates.size();
716       clauseSegments[pos[allocateClause] + 1] = allocators.size();
717     } else if (clauseKeyword == "default") {
718       StringRef defval;
719       llvm::SMLoc loc = parser.getCurrentLocation();
720       if (checkAllowed(defaultClause) || parser.parseLParen() ||
721           parser.parseKeyword(&defval) || parser.parseRParen())
722         return failure();
723       // The def prefix is required for the attribute as "private" is a keyword
724       // in C++.
725       if (Optional<ClauseDefault> def =
726               symbolizeClauseDefault(("def" + defval).str())) {
727         result.addAttribute("default_val",
728                             ClauseDefaultAttr::get(parser.getContext(), *def));
729       } else {
730         return parser.emitError(loc, "invalid default clause");
731       }
732     } else if (clauseKeyword == "proc_bind") {
733       if (checkAllowed(procBindClause) ||
734           parseClauseAttr<ClauseProcBindKindAttr>(parser, result,
735                                                   "proc_bind_val", "proc bind"))
736         return failure();
737     } else if (clauseKeyword == "reduction") {
738       if (checkAllowed(reductionClause) ||
739           parseReductionVarList(parser, reductionSymbols, reductionVars,
740                                 reductionVarTypes))
741         return failure();
742       clauseSegments[pos[reductionClause]] = reductionVars.size();
743     } else if (clauseKeyword == "nowait") {
744       if (checkAllowed(nowaitClause))
745         return failure();
746       auto attr = UnitAttr::get(parser.getBuilder().getContext());
747       result.addAttribute("nowait", attr);
748     } else if (clauseKeyword == "linear") {
749       if (checkAllowed(linearClause) ||
750           parseLinearClause(parser, linears, linearTypes, linearSteps))
751         return failure();
752       clauseSegments[pos[linearClause]] = linears.size();
753       clauseSegments[pos[linearClause] + 1] = linearSteps.size();
754     } else if (clauseKeyword == "schedule") {
755       if (checkAllowed(scheduleClause) ||
756           parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize,
757                               scheduleChunkType))
758         return failure();
759       if (scheduleChunkSize) {
760         clauseSegments[pos[scheduleClause]] = 1;
761       }
762     } else if (clauseKeyword == "collapse") {
763       auto type = parser.getBuilder().getI64Type();
764       mlir::IntegerAttr attr;
765       if (checkAllowed(collapseClause) || parser.parseLParen() ||
766           parser.parseAttribute(attr, type) || parser.parseRParen())
767         return failure();
768       result.addAttribute("collapse_val", attr);
769     } else if (clauseKeyword == "ordered") {
770       mlir::IntegerAttr attr;
771       if (checkAllowed(orderedClause))
772         return failure();
773       if (succeeded(parser.parseOptionalLParen())) {
774         auto type = parser.getBuilder().getI64Type();
775         if (parser.parseAttribute(attr, type) || parser.parseRParen())
776           return failure();
777       } else {
778         // Use 0 to represent no ordered parameter was specified
779         attr = parser.getBuilder().getI64IntegerAttr(0);
780       }
781       result.addAttribute("ordered_val", attr);
782     } else if (clauseKeyword == "order") {
783       if (checkAllowed(orderClause) ||
784           parseClauseAttr<ClauseOrderKindAttr>(parser, result, "order_val",
785                                                "order"))
786         return failure();
787     } else if (clauseKeyword == "memory_order") {
788       if (checkAllowed(memoryOrderClause) ||
789           parseClauseAttr<ClauseMemoryOrderKindAttr>(
790               parser, result, "memory_order", "memory order"))
791         return failure();
792     } else if (clauseKeyword == "hint") {
793       IntegerAttr hint;
794       if (checkAllowed(hintClause) ||
795           parseSynchronizationHint(parser, hint, false))
796         return failure();
797       result.addAttribute("hint", hint);
798     } else {
799       return parser.emitError(parser.getNameLoc())
800              << clauseKeyword << " is not a valid clause";
801     }
802   }
803 
804   // Add if parameter.
805   if (done[ifClause] && clauseSegments[pos[ifClause]] &&
806       failed(
807           parser.resolveOperand(ifCond.first, ifCond.second, result.operands)))
808     return failure();
809 
810   // Add num_threads parameter.
811   if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] &&
812       failed(parser.resolveOperand(numThreads.first, numThreads.second,
813                                    result.operands)))
814     return failure();
815 
816   // Add private parameters.
817   if (done[privateClause] && clauseSegments[pos[privateClause]] &&
818       failed(parser.resolveOperands(privates, privateTypes,
819                                     privates[0].location, result.operands)))
820     return failure();
821 
822   // Add firstprivate parameters.
823   if (done[firstprivateClause] && clauseSegments[pos[firstprivateClause]] &&
824       failed(parser.resolveOperands(firstprivates, firstprivateTypes,
825                                     firstprivates[0].location,
826                                     result.operands)))
827     return failure();
828 
829   // Add lastprivate parameters.
830   if (done[lastprivateClause] && clauseSegments[pos[lastprivateClause]] &&
831       failed(parser.resolveOperands(lastprivates, lastprivateTypes,
832                                     lastprivates[0].location, result.operands)))
833     return failure();
834 
835   // Add shared parameters.
836   if (done[sharedClause] && clauseSegments[pos[sharedClause]] &&
837       failed(parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
838                                     result.operands)))
839     return failure();
840 
841   // Add copyin parameters.
842   if (done[copyinClause] && clauseSegments[pos[copyinClause]] &&
843       failed(parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
844                                     result.operands)))
845     return failure();
846 
847   // Add allocate parameters.
848   if (done[allocateClause] && clauseSegments[pos[allocateClause]] &&
849       failed(parser.resolveOperands(allocates, allocateTypes,
850                                     allocates[0].location, result.operands)))
851     return failure();
852 
853   // Add allocator parameters.
854   if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] &&
855       failed(parser.resolveOperands(allocators, allocatorTypes,
856                                     allocators[0].location, result.operands)))
857     return failure();
858 
859   // Add reduction parameters and symbols
860   if (done[reductionClause] && clauseSegments[pos[reductionClause]]) {
861     if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
862                                       parser.getNameLoc(), result.operands)))
863       return failure();
864 
865     SmallVector<Attribute> reductions(reductionSymbols.begin(),
866                                       reductionSymbols.end());
867     result.addAttribute("reductions",
868                         parser.getBuilder().getArrayAttr(reductions));
869   }
870 
871   // Add linear parameters
872   if (done[linearClause] && clauseSegments[pos[linearClause]]) {
873     auto linearStepType = parser.getBuilder().getI32Type();
874     SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
875     if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location,
876                                       result.operands)) ||
877         failed(parser.resolveOperands(linearSteps, linearStepTypes,
878                                       linearSteps[0].location,
879                                       result.operands)))
880       return failure();
881   }
882 
883   // Add schedule parameters
884   if (done[scheduleClause] && !schedule.empty()) {
885     schedule[0] = llvm::toUpper(schedule[0]);
886     if (Optional<ClauseScheduleKind> sched =
887             symbolizeClauseScheduleKind(schedule)) {
888       auto attr = ClauseScheduleKindAttr::get(parser.getContext(), *sched);
889       result.addAttribute("schedule_val", attr);
890     } else {
891       return parser.emitError(parser.getCurrentLocation(),
892                               "invalid schedule kind");
893     }
894     if (!modifiers.empty()) {
895       llvm::SMLoc loc = parser.getCurrentLocation();
896       if (Optional<ScheduleModifier> mod =
897               symbolizeScheduleModifier(modifiers[0])) {
898         result.addAttribute(
899             "schedule_modifier",
900             ScheduleModifierAttr::get(parser.getContext(), *mod));
901       } else {
902         return parser.emitError(loc, "invalid schedule modifier");
903       }
904       // Only SIMD attribute is allowed here!
905       if (modifiers.size() > 1) {
906         assert(symbolizeScheduleModifier(modifiers[1]) ==
907                ScheduleModifier::simd);
908         auto attr = UnitAttr::get(parser.getBuilder().getContext());
909         result.addAttribute("simd_modifier", attr);
910       }
911     }
912     if (scheduleChunkSize)
913       parser.resolveOperand(*scheduleChunkSize, scheduleChunkType,
914                             result.operands);
915   }
916 
917   segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end());
918 
919   return success();
920 }
921 
922 /// Parses a parallel operation.
923 ///
924 /// operation ::= `omp.parallel` clause-list
925 /// clause-list ::= clause | clause clause-list
926 /// clause ::= if | num-threads | private | firstprivate | shared | copyin |
927 ///            allocate | default | proc-bind
928 ///
929 static ParseResult parseParallelOp(OpAsmParser &parser,
930                                    OperationState &result) {
931   SmallVector<ClauseType> clauses = {
932       ifClause,           numThreadsClause, privateClause,
933       firstprivateClause, sharedClause,     copyinClause,
934       allocateClause,     defaultClause,    procBindClause};
935 
936   SmallVector<int> segments;
937 
938   if (failed(parseClauses(parser, result, clauses, segments)))
939     return failure();
940 
941   result.addAttribute("operand_segment_sizes",
942                       parser.getBuilder().getI32VectorAttr(segments));
943 
944   Region *body = result.addRegion();
945   SmallVector<OpAsmParser::OperandType> regionArgs;
946   SmallVector<Type> regionArgTypes;
947   if (parser.parseRegion(*body, regionArgs, regionArgTypes))
948     return failure();
949   return success();
950 }
951 
952 //===----------------------------------------------------------------------===//
953 // Parser, printer and verifier for SectionsOp
954 //===----------------------------------------------------------------------===//
955 
956 /// Parses an OpenMP Sections operation
957 ///
958 /// sections ::= `omp.sections` clause-list
959 /// clause-list ::= clause clause-list | empty
960 /// clause ::= private | firstprivate | lastprivate | reduction | allocate |
961 ///            nowait
962 static ParseResult parseSectionsOp(OpAsmParser &parser,
963                                    OperationState &result) {
964 
965   SmallVector<ClauseType> clauses = {privateClause,     firstprivateClause,
966                                      lastprivateClause, reductionClause,
967                                      allocateClause,    nowaitClause};
968 
969   SmallVector<int> segments;
970 
971   if (failed(parseClauses(parser, result, clauses, segments)))
972     return failure();
973 
974   result.addAttribute("operand_segment_sizes",
975                       parser.getBuilder().getI32VectorAttr(segments));
976 
977   // Now parse the body.
978   Region *body = result.addRegion();
979   if (parser.parseRegion(*body))
980     return failure();
981   return success();
982 }
983 
984 static void printSectionsOp(OpAsmPrinter &p, SectionsOp op) {
985   p << " ";
986   printDataVars(p, op.private_vars(), "private");
987   printDataVars(p, op.firstprivate_vars(), "firstprivate");
988   printDataVars(p, op.lastprivate_vars(), "lastprivate");
989 
990   if (!op.reduction_vars().empty())
991     printReductionVarList(p, op.reductions(), op.reduction_vars());
992 
993   if (!op.allocate_vars().empty())
994     printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars());
995 
996   if (op.nowait())
997     p << "nowait";
998 
999   p << ' ';
1000   p.printRegion(op.region());
1001 }
1002 
1003 static LogicalResult verifySectionsOp(SectionsOp op) {
1004 
1005   // A list item may not appear in more than one clause on the same directive,
1006   // except that it may be specified in both firstprivate and lastprivate
1007   // clauses.
1008   for (auto var : op.private_vars()) {
1009     if (llvm::is_contained(op.firstprivate_vars(), var))
1010       return op.emitOpError()
1011              << "operand used in both private and firstprivate clauses";
1012     if (llvm::is_contained(op.lastprivate_vars(), var))
1013       return op.emitOpError()
1014              << "operand used in both private and lastprivate clauses";
1015   }
1016 
1017   if (op.allocate_vars().size() != op.allocators_vars().size())
1018     return op.emitError(
1019         "expected equal sizes for allocate and allocator variables");
1020 
1021   for (auto &inst : *op.region().begin()) {
1022     if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst)))
1023       op.emitOpError()
1024           << "expected omp.section op or terminator op inside region";
1025   }
1026 
1027   return verifyReductionVarList(op, op.reductions(), op.reduction_vars());
1028 }
1029 
1030 /// Parses an OpenMP Workshare Loop operation
1031 ///
1032 /// wsloop ::= `omp.wsloop` loop-control clause-list
1033 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
1034 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
1035 /// steps := `step` `(`ssa-id-list`)`
1036 /// clause-list ::= clause clause-list | empty
1037 /// clause ::= private | firstprivate | lastprivate | linear | schedule |
1038 //             collapse | nowait | ordered | order | reduction
1039 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
1040 
1041   // Parse an opening `(` followed by induction variables followed by `)`
1042   SmallVector<OpAsmParser::OperandType> ivs;
1043   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
1044                                      OpAsmParser::Delimiter::Paren))
1045     return failure();
1046 
1047   int numIVs = static_cast<int>(ivs.size());
1048   Type loopVarType;
1049   if (parser.parseColonType(loopVarType))
1050     return failure();
1051 
1052   // Parse loop bounds.
1053   SmallVector<OpAsmParser::OperandType> lower;
1054   if (parser.parseEqual() ||
1055       parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
1056       parser.resolveOperands(lower, loopVarType, result.operands))
1057     return failure();
1058 
1059   SmallVector<OpAsmParser::OperandType> upper;
1060   if (parser.parseKeyword("to") ||
1061       parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
1062       parser.resolveOperands(upper, loopVarType, result.operands))
1063     return failure();
1064 
1065   if (succeeded(parser.parseOptionalKeyword("inclusive"))) {
1066     auto attr = UnitAttr::get(parser.getBuilder().getContext());
1067     result.addAttribute("inclusive", attr);
1068   }
1069 
1070   // Parse step values.
1071   SmallVector<OpAsmParser::OperandType> steps;
1072   if (parser.parseKeyword("step") ||
1073       parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
1074       parser.resolveOperands(steps, loopVarType, result.operands))
1075     return failure();
1076 
1077   SmallVector<ClauseType> clauses = {
1078       privateClause,   firstprivateClause, lastprivateClause, linearClause,
1079       reductionClause, collapseClause,     orderClause,       orderedClause,
1080       nowaitClause,    scheduleClause};
1081   SmallVector<int> segments{numIVs, numIVs, numIVs};
1082   if (failed(parseClauses(parser, result, clauses, segments)))
1083     return failure();
1084 
1085   result.addAttribute("operand_segment_sizes",
1086                       parser.getBuilder().getI32VectorAttr(segments));
1087 
1088   // Now parse the body.
1089   Region *body = result.addRegion();
1090   SmallVector<Type> ivTypes(numIVs, loopVarType);
1091   SmallVector<OpAsmParser::OperandType> blockArgs(ivs);
1092   if (parser.parseRegion(*body, blockArgs, ivTypes))
1093     return failure();
1094   return success();
1095 }
1096 
1097 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
1098   auto args = op.getRegion().front().getArguments();
1099   p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound()
1100     << ") to (" << op.upperBound() << ") ";
1101   if (op.inclusive()) {
1102     p << "inclusive ";
1103   }
1104   p << "step (" << op.step() << ") ";
1105 
1106   printDataVars(p, op.private_vars(), "private");
1107   printDataVars(p, op.firstprivate_vars(), "firstprivate");
1108   printDataVars(p, op.lastprivate_vars(), "lastprivate");
1109 
1110   if (!op.linear_vars().empty())
1111     printLinearClause(p, op.linear_vars(), op.linear_step_vars());
1112 
1113   if (auto sched = op.schedule_val())
1114     printScheduleClause(p, sched.getValue(), op.schedule_modifier(),
1115                         op.simd_modifier(), op.schedule_chunk_var());
1116 
1117   if (auto collapse = op.collapse_val())
1118     p << "collapse(" << collapse << ") ";
1119 
1120   if (op.nowait())
1121     p << "nowait ";
1122 
1123   if (auto ordered = op.ordered_val())
1124     p << "ordered(" << ordered << ") ";
1125 
1126   if (auto order = op.order_val())
1127     p << "order(" << stringifyClauseOrderKind(*order) << ") ";
1128 
1129   if (!op.reduction_vars().empty())
1130     printReductionVarList(p, op.reductions(), op.reduction_vars());
1131 
1132   p << ' ';
1133   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
1134 }
1135 
1136 //===----------------------------------------------------------------------===//
1137 // ReductionOp
1138 //===----------------------------------------------------------------------===//
1139 
1140 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
1141                                               Region &region) {
1142   if (parser.parseOptionalKeyword("atomic"))
1143     return success();
1144   return parser.parseRegion(region);
1145 }
1146 
1147 static void printAtomicReductionRegion(OpAsmPrinter &printer,
1148                                        ReductionDeclareOp op, Region &region) {
1149   if (region.empty())
1150     return;
1151   printer << "atomic ";
1152   printer.printRegion(region);
1153 }
1154 
1155 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) {
1156   if (op.initializerRegion().empty())
1157     return op.emitOpError() << "expects non-empty initializer region";
1158   Block &initializerEntryBlock = op.initializerRegion().front();
1159   if (initializerEntryBlock.getNumArguments() != 1 ||
1160       initializerEntryBlock.getArgument(0).getType() != op.type()) {
1161     return op.emitOpError() << "expects initializer region with one argument "
1162                                "of the reduction type";
1163   }
1164 
1165   for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) {
1166     if (yieldOp.results().size() != 1 ||
1167         yieldOp.results().getTypes()[0] != op.type())
1168       return op.emitOpError() << "expects initializer region to yield a value "
1169                                  "of the reduction type";
1170   }
1171 
1172   if (op.reductionRegion().empty())
1173     return op.emitOpError() << "expects non-empty reduction region";
1174   Block &reductionEntryBlock = op.reductionRegion().front();
1175   if (reductionEntryBlock.getNumArguments() != 2 ||
1176       reductionEntryBlock.getArgumentTypes()[0] !=
1177           reductionEntryBlock.getArgumentTypes()[1] ||
1178       reductionEntryBlock.getArgumentTypes()[0] != op.type())
1179     return op.emitOpError() << "expects reduction region with two arguments of "
1180                                "the reduction type";
1181   for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) {
1182     if (yieldOp.results().size() != 1 ||
1183         yieldOp.results().getTypes()[0] != op.type())
1184       return op.emitOpError() << "expects reduction region to yield a value "
1185                                  "of the reduction type";
1186   }
1187 
1188   if (op.atomicReductionRegion().empty())
1189     return success();
1190 
1191   Block &atomicReductionEntryBlock = op.atomicReductionRegion().front();
1192   if (atomicReductionEntryBlock.getNumArguments() != 2 ||
1193       atomicReductionEntryBlock.getArgumentTypes()[0] !=
1194           atomicReductionEntryBlock.getArgumentTypes()[1])
1195     return op.emitOpError() << "expects atomic reduction region with two "
1196                                "arguments of the same type";
1197   auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
1198                      .dyn_cast<PointerLikeType>();
1199   if (!ptrType || ptrType.getElementType() != op.type())
1200     return op.emitOpError() << "expects atomic reduction region arguments to "
1201                                "be accumulators containing the reduction type";
1202   return success();
1203 }
1204 
1205 static LogicalResult verifyReductionOp(ReductionOp op) {
1206   // TODO: generalize this to an op interface when there is more than one op
1207   // that supports reductions.
1208   auto container = op->getParentOfType<WsLoopOp>();
1209   for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
1210     if (container.reduction_vars()[i] == op.accumulator())
1211       return success();
1212 
1213   return op.emitOpError() << "the accumulator is not used by the parent";
1214 }
1215 
1216 //===----------------------------------------------------------------------===//
1217 // WsLoopOp
1218 //===----------------------------------------------------------------------===//
1219 
1220 void WsLoopOp::build(OpBuilder &builder, OperationState &state,
1221                      ValueRange lowerBound, ValueRange upperBound,
1222                      ValueRange step, ArrayRef<NamedAttribute> attributes) {
1223   build(builder, state, TypeRange(), lowerBound, upperBound, step,
1224         /*privateVars=*/ValueRange(),
1225         /*firstprivateVars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
1226         /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
1227         /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr,
1228         /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr,
1229         /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr,
1230         /*inclusive=*/nullptr, /*buildBody=*/false);
1231   state.addAttributes(attributes);
1232 }
1233 
1234 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes,
1235                      ValueRange operands, ArrayRef<NamedAttribute> attributes) {
1236   state.addOperands(operands);
1237   state.addAttributes(attributes);
1238   (void)state.addRegion();
1239   assert(resultTypes.empty() && "mismatched number of return types");
1240   state.addTypes(resultTypes);
1241 }
1242 
1243 void WsLoopOp::build(OpBuilder &builder, OperationState &result,
1244                      TypeRange typeRange, ValueRange lowerBounds,
1245                      ValueRange upperBounds, ValueRange steps,
1246                      ValueRange privateVars, ValueRange firstprivateVars,
1247                      ValueRange lastprivateVars, ValueRange linearVars,
1248                      ValueRange linearStepVars, ValueRange reductionVars,
1249                      StringAttr scheduleVal, Value scheduleChunkVar,
1250                      IntegerAttr collapseVal, UnitAttr nowait,
1251                      IntegerAttr orderedVal, StringAttr orderVal,
1252                      UnitAttr inclusive, bool buildBody) {
1253   result.addOperands(lowerBounds);
1254   result.addOperands(upperBounds);
1255   result.addOperands(steps);
1256   result.addOperands(privateVars);
1257   result.addOperands(firstprivateVars);
1258   result.addOperands(linearVars);
1259   result.addOperands(linearStepVars);
1260   if (scheduleChunkVar)
1261     result.addOperands(scheduleChunkVar);
1262 
1263   if (scheduleVal)
1264     result.addAttribute("schedule_val", scheduleVal);
1265   if (collapseVal)
1266     result.addAttribute("collapse_val", collapseVal);
1267   if (nowait)
1268     result.addAttribute("nowait", nowait);
1269   if (orderedVal)
1270     result.addAttribute("ordered_val", orderedVal);
1271   if (orderVal)
1272     result.addAttribute("order", orderVal);
1273   if (inclusive)
1274     result.addAttribute("inclusive", inclusive);
1275   result.addAttribute(
1276       WsLoopOp::getOperandSegmentSizeAttr(),
1277       builder.getI32VectorAttr(
1278           {static_cast<int32_t>(lowerBounds.size()),
1279            static_cast<int32_t>(upperBounds.size()),
1280            static_cast<int32_t>(steps.size()),
1281            static_cast<int32_t>(privateVars.size()),
1282            static_cast<int32_t>(firstprivateVars.size()),
1283            static_cast<int32_t>(lastprivateVars.size()),
1284            static_cast<int32_t>(linearVars.size()),
1285            static_cast<int32_t>(linearStepVars.size()),
1286            static_cast<int32_t>(reductionVars.size()),
1287            static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)}));
1288 
1289   Region *bodyRegion = result.addRegion();
1290   if (buildBody) {
1291     OpBuilder::InsertionGuard guard(builder);
1292     unsigned numIVs = steps.size();
1293     SmallVector<Type, 8> argTypes(numIVs, steps.getType().front());
1294     SmallVector<Location, 8> argLocs(numIVs, result.location);
1295     builder.createBlock(bodyRegion, {}, argTypes, argLocs);
1296   }
1297 }
1298 
1299 static LogicalResult verifyWsLoopOp(WsLoopOp op) {
1300   return verifyReductionVarList(op, op.reductions(), op.reduction_vars());
1301 }
1302 
1303 //===----------------------------------------------------------------------===//
1304 // Verifier for critical construct (2.17.1)
1305 //===----------------------------------------------------------------------===//
1306 
1307 static LogicalResult verifyCriticalDeclareOp(CriticalDeclareOp op) {
1308   return verifySynchronizationHint(op, op.hint());
1309 }
1310 
1311 static LogicalResult verifyCriticalOp(CriticalOp op) {
1312 
1313   if (op.nameAttr()) {
1314     auto symbolRef = op.nameAttr().cast<SymbolRefAttr>();
1315     auto decl =
1316         SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef);
1317     if (!decl) {
1318       return op.emitOpError() << "expected symbol reference " << symbolRef
1319                               << " to point to a critical declaration";
1320     }
1321   }
1322 
1323   return success();
1324 }
1325 
1326 //===----------------------------------------------------------------------===//
1327 // Verifier for ordered construct
1328 //===----------------------------------------------------------------------===//
1329 
1330 static LogicalResult verifyOrderedOp(OrderedOp op) {
1331   auto container = op->getParentOfType<WsLoopOp>();
1332   if (!container || !container.ordered_valAttr() ||
1333       container.ordered_valAttr().getInt() == 0)
1334     return op.emitOpError() << "ordered depend directive must be closely "
1335                             << "nested inside a worksharing-loop with ordered "
1336                             << "clause with parameter present";
1337 
1338   if (container.ordered_valAttr().getInt() !=
1339       (int64_t)op.num_loops_val().getValue())
1340     return op.emitOpError() << "number of variables in depend clause does not "
1341                             << "match number of iteration variables in the "
1342                             << "doacross loop";
1343 
1344   return success();
1345 }
1346 
1347 static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) {
1348   // TODO: The code generation for ordered simd directive is not supported yet.
1349   if (op.simd())
1350     return failure();
1351 
1352   if (auto container = op->getParentOfType<WsLoopOp>()) {
1353     if (!container.ordered_valAttr() ||
1354         container.ordered_valAttr().getInt() != 0)
1355       return op.emitOpError() << "ordered region must be closely nested inside "
1356                               << "a worksharing-loop region with an ordered "
1357                               << "clause without parameter present";
1358   }
1359 
1360   return success();
1361 }
1362 
1363 //===----------------------------------------------------------------------===//
1364 // AtomicReadOp
1365 //===----------------------------------------------------------------------===//
1366 
1367 /// Parser for AtomicReadOp
1368 ///
1369 /// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type
1370 /// address ::= operand `:` type
1371 static ParseResult parseAtomicReadOp(OpAsmParser &parser,
1372                                      OperationState &result) {
1373   OpAsmParser::OperandType x, v;
1374   Type addressType;
1375   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1376   SmallVector<int> segments;
1377 
1378   if (parser.parseOperand(v) || parser.parseEqual() || parser.parseOperand(x) ||
1379       parseClauses(parser, result, clauses, segments) ||
1380       parser.parseColonType(addressType) ||
1381       parser.resolveOperand(x, addressType, result.operands) ||
1382       parser.resolveOperand(v, addressType, result.operands))
1383     return failure();
1384 
1385   return success();
1386 }
1387 
1388 /// Printer for AtomicReadOp
1389 static void printAtomicReadOp(OpAsmPrinter &p, AtomicReadOp op) {
1390   p << " " << op.v() << " = " << op.x() << " ";
1391   if (auto mo = op.memory_order())
1392     p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
1393   if (op.hintAttr())
1394     printSynchronizationHint(p << " ", op, op.hintAttr());
1395   p << ": " << op.x().getType();
1396 }
1397 
1398 /// Verifier for AtomicReadOp
1399 static LogicalResult verifyAtomicReadOp(AtomicReadOp op) {
1400   if (auto mo = op.memory_order()) {
1401     if (*mo == ClauseMemoryOrderKind::acq_rel ||
1402         *mo == ClauseMemoryOrderKind::release) {
1403       return op.emitError(
1404           "memory-order must not be acq_rel or release for atomic reads");
1405     }
1406   }
1407   if (op.x() == op.v())
1408     return op.emitError(
1409         "read and write must not be to the same location for atomic reads");
1410   return verifySynchronizationHint(op, op.hint());
1411 }
1412 
1413 //===----------------------------------------------------------------------===//
1414 // AtomicWriteOp
1415 //===----------------------------------------------------------------------===//
1416 
1417 /// Parser for AtomicWriteOp
1418 ///
1419 /// operation ::= `omp.atomic.write` atomic-clause-list operands
1420 /// operands ::= address `,` value
1421 /// address ::= operand `:` type
1422 /// value ::= operand `:` type
1423 static ParseResult parseAtomicWriteOp(OpAsmParser &parser,
1424                                       OperationState &result) {
1425   OpAsmParser::OperandType address, value;
1426   Type addrType, valueType;
1427   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1428   SmallVector<int> segments;
1429 
1430   if (parser.parseOperand(address) || parser.parseEqual() ||
1431       parser.parseOperand(value) ||
1432       parseClauses(parser, result, clauses, segments) ||
1433       parser.parseColonType(addrType) || parser.parseComma() ||
1434       parser.parseType(valueType) ||
1435       parser.resolveOperand(address, addrType, result.operands) ||
1436       parser.resolveOperand(value, valueType, result.operands))
1437     return failure();
1438   return success();
1439 }
1440 
1441 /// Printer for AtomicWriteOp
1442 static void printAtomicWriteOp(OpAsmPrinter &p, AtomicWriteOp op) {
1443   p << " " << op.address() << " = " << op.value() << " ";
1444   if (auto mo = op.memory_order())
1445     p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
1446   if (op.hintAttr())
1447     printSynchronizationHint(p, op, op.hintAttr());
1448   p << ": " << op.address().getType() << ", " << op.value().getType();
1449 }
1450 
1451 /// Verifier for AtomicWriteOp
1452 static LogicalResult verifyAtomicWriteOp(AtomicWriteOp op) {
1453   if (auto mo = op.memory_order()) {
1454     if (*mo == ClauseMemoryOrderKind::acq_rel ||
1455         *mo == ClauseMemoryOrderKind::acquire) {
1456       return op.emitError(
1457           "memory-order must not be acq_rel or acquire for atomic writes");
1458     }
1459   }
1460   return verifySynchronizationHint(op, op.hint());
1461 }
1462 
1463 //===----------------------------------------------------------------------===//
1464 // AtomicUpdateOp
1465 //===----------------------------------------------------------------------===//
1466 
1467 /// Parser for AtomicUpdateOp
1468 ///
1469 /// operation ::= `omp.atomic.update` atomic-clause-list region
1470 static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
1471                                        OperationState &result) {
1472   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1473   SmallVector<int> segments;
1474   OpAsmParser::OperandType x, y, z;
1475   Type xType, exprType;
1476   StringRef binOp;
1477 
1478   // x = y `op` z : xtype, exprtype
1479   if (parser.parseOperand(x) || parser.parseEqual() || parser.parseOperand(y) ||
1480       parser.parseKeyword(&binOp) || parser.parseOperand(z) ||
1481       parseClauses(parser, result, clauses, segments) || parser.parseColon() ||
1482       parser.parseType(xType) || parser.parseComma() ||
1483       parser.parseType(exprType) ||
1484       parser.resolveOperand(x, xType, result.operands)) {
1485     return failure();
1486   }
1487 
1488   auto binOpEnum = AtomicBinOpKindToEnum(binOp.upper());
1489   if (!binOpEnum)
1490     return parser.emitError(parser.getNameLoc())
1491            << "invalid atomic bin op in atomic update\n";
1492   auto attr =
1493       parser.getBuilder().getI64IntegerAttr((int64_t)binOpEnum.getValue());
1494   result.addAttribute("binop", attr);
1495 
1496   OpAsmParser::OperandType expr;
1497   if (x.name == y.name && x.number == y.number) {
1498     expr = z;
1499     result.addAttribute("isXBinopExpr", parser.getBuilder().getUnitAttr());
1500   } else if (x.name == z.name && x.number == z.number) {
1501     expr = y;
1502   } else {
1503     return parser.emitError(parser.getNameLoc())
1504            << "atomic update variable " << x.name
1505            << " not found in the RHS of the assignment statement in an"
1506               " atomic.update operation";
1507   }
1508   return parser.resolveOperand(expr, exprType, result.operands);
1509 }
1510 
1511 /// Printer for AtomicUpdateOp
1512 static void printAtomicUpdateOp(OpAsmPrinter &p, AtomicUpdateOp op) {
1513   p << " " << op.x() << " = ";
1514   Value y, z;
1515   if (op.isXBinopExpr()) {
1516     y = op.x();
1517     z = op.expr();
1518   } else {
1519     y = op.expr();
1520     z = op.x();
1521   }
1522   p << y << " " << AtomicBinOpKindToString(op.binop()).lower() << " " << z
1523     << " ";
1524   if (auto mo = op.memory_order())
1525     p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
1526   if (op.hintAttr())
1527     printSynchronizationHint(p, op, op.hintAttr());
1528   p << ": " << op.x().getType() << ", " << op.expr().getType();
1529 }
1530 
1531 /// Verifier for AtomicUpdateOp
1532 static LogicalResult verifyAtomicUpdateOp(AtomicUpdateOp op) {
1533   if (auto mo = op.memory_order()) {
1534     if (*mo == ClauseMemoryOrderKind::acq_rel ||
1535         *mo == ClauseMemoryOrderKind::acquire) {
1536       return op.emitError(
1537           "memory-order must not be acq_rel or acquire for atomic updates");
1538     }
1539   }
1540   return success();
1541 }
1542 
1543 #define GET_ATTRDEF_CLASSES
1544 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
1545 
1546 #define GET_OP_CLASSES
1547 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
1548