1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/OperationSupport.h"
19 
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/SmallString.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/ADT/StringSwitch.h"
25 #include <cstddef>
26 
27 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
28 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
29 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
30 
31 using namespace mlir;
32 using namespace mlir::omp;
33 
34 namespace {
35 /// Model for pointer-like types that already provide a `getElementType` method.
36 template <typename T>
37 struct PointerLikeModel
38     : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
39   Type getElementType(Type pointer) const {
40     return pointer.cast<T>().getElementType();
41   }
42 };
43 } // namespace
44 
/// Sets up the OpenMP dialect: registers its operations and attaches the
/// `PointerLikeType` external models defined above to the pointer-like types
/// used by the dialect.
void OpenMPDialect::initialize() {
  // The list of operation classes is taken from the GET_OP_LIST section of the
  // generated OpenMPOps.cpp.inc file.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();

  // Both LLVM pointers and memrefs already provide `getElementType`, so the
  // generic `PointerLikeModel` is attached to each of them.
  LLVM::LLVMPointerType::attachInterface<
      PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
  MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
}
55 
56 //===----------------------------------------------------------------------===//
57 // ParallelOp
58 //===----------------------------------------------------------------------===//
59 
60 void ParallelOp::build(OpBuilder &builder, OperationState &state,
61                        ArrayRef<NamedAttribute> attributes) {
62   ParallelOp::build(
63       builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
64       /*default_val=*/nullptr, /*private_vars=*/ValueRange(),
65       /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(),
66       /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(),
67       /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr);
68   state.addAttributes(attributes);
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // Parser and printer for Operand and type list
73 //===----------------------------------------------------------------------===//
74 
75 /// Parse a list of operands with types.
76 ///
77 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
78 /// ssa-id-and-type-list ::= ssa-id-and-type |
79 ///                          ssa-id-and-type `,` ssa-id-and-type-list
80 /// ssa-id-and-type ::= ssa-id `:` type
81 static ParseResult
82 parseOperandAndTypeList(OpAsmParser &parser,
83                         SmallVectorImpl<OpAsmParser::OperandType> &operands,
84                         SmallVectorImpl<Type> &types) {
85   return parser.parseCommaSeparatedList(
86       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
87         OpAsmParser::OperandType operand;
88         Type type;
89         if (parser.parseOperand(operand) || parser.parseColonType(type))
90           return failure();
91         operands.push_back(operand);
92         types.push_back(type);
93         return success();
94       });
95 }
96 
97 /// Print an operand and type list with parentheses
98 static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) {
99   p << "(";
100   llvm::interleaveComma(
101       operands, p, [&](const Value &v) { p << v << " : " << v.getType(); });
102   p << ") ";
103 }
104 
105 /// Print data variables corresponding to a data-sharing clause `name`
106 static void printDataVars(OpAsmPrinter &p, OperandRange operands,
107                           StringRef name) {
108   if (!operands.empty()) {
109     p << name;
110     printOperandAndTypeList(p, operands);
111   }
112 }
113 
114 //===----------------------------------------------------------------------===//
115 // Parser and printer for Allocate Clause
116 //===----------------------------------------------------------------------===//
117 
118 /// Parse an allocate clause with allocators and a list of operands with types.
119 ///
120 /// allocate ::= `allocate` `(` allocate-operand-list `)`
121 /// allocate-operand-list :: = allocate-operand |
122 ///                            allocator-operand `,` allocate-operand-list
123 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
124 /// ssa-id-and-type ::= ssa-id `:` type
125 static ParseResult parseAllocateAndAllocator(
126     OpAsmParser &parser,
127     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
128     SmallVectorImpl<Type> &typesAllocate,
129     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
130     SmallVectorImpl<Type> &typesAllocator) {
131 
132   return parser.parseCommaSeparatedList(
133       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
134         OpAsmParser::OperandType operand;
135         Type type;
136         if (parser.parseOperand(operand) || parser.parseColonType(type))
137           return failure();
138         operandsAllocator.push_back(operand);
139         typesAllocator.push_back(type);
140         if (parser.parseArrow())
141           return failure();
142         if (parser.parseOperand(operand) || parser.parseColonType(type))
143           return failure();
144 
145         operandsAllocate.push_back(operand);
146         typesAllocate.push_back(type);
147         return success();
148       });
149 }
150 
151 /// Print allocate clause
152 static void printAllocateAndAllocator(OpAsmPrinter &p,
153                                       OperandRange varsAllocate,
154                                       OperandRange varsAllocator) {
155   p << "allocate(";
156   for (unsigned i = 0; i < varsAllocate.size(); ++i) {
157     std::string separator = i == varsAllocate.size() - 1 ? ") " : ", ";
158     p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
159     p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
160   }
161 }
162 
163 static LogicalResult verifyParallelOp(ParallelOp op) {
164   if (op.allocate_vars().size() != op.allocators_vars().size())
165     return op.emitError(
166         "expected equal sizes for allocate and allocator variables");
167   return success();
168 }
169 
170 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
171   p << " ";
172   if (auto ifCond = op.if_expr_var())
173     p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
174 
175   if (auto threads = op.num_threads_var())
176     p << "num_threads(" << threads << " : " << threads.getType() << ") ";
177 
178   printDataVars(p, op.private_vars(), "private");
179   printDataVars(p, op.firstprivate_vars(), "firstprivate");
180   printDataVars(p, op.shared_vars(), "shared");
181   printDataVars(p, op.copyin_vars(), "copyin");
182 
183   if (!op.allocate_vars().empty())
184     printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars());
185 
186   if (auto def = op.default_val())
187     p << "default(" << def->drop_front(3) << ") ";
188 
189   if (auto bind = op.proc_bind_val())
190     p << "proc_bind(" << bind << ") ";
191 
192   p.printRegion(op.getRegion());
193 }
194 
195 //===----------------------------------------------------------------------===//
196 // Parser and printer for Linear Clause
197 //===----------------------------------------------------------------------===//
198 
199 /// linear ::= `linear` `(` linear-list `)`
200 /// linear-list := linear-val | linear-val linear-list
201 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
202 static ParseResult
203 parseLinearClause(OpAsmParser &parser,
204                   SmallVectorImpl<OpAsmParser::OperandType> &vars,
205                   SmallVectorImpl<Type> &types,
206                   SmallVectorImpl<OpAsmParser::OperandType> &stepVars) {
207   if (parser.parseLParen())
208     return failure();
209 
210   do {
211     OpAsmParser::OperandType var;
212     Type type;
213     OpAsmParser::OperandType stepVar;
214     if (parser.parseOperand(var) || parser.parseEqual() ||
215         parser.parseOperand(stepVar) || parser.parseColonType(type))
216       return failure();
217 
218     vars.push_back(var);
219     types.push_back(type);
220     stepVars.push_back(stepVar);
221   } while (succeeded(parser.parseOptionalComma()));
222 
223   if (parser.parseRParen())
224     return failure();
225 
226   return success();
227 }
228 
229 /// Print Linear Clause
230 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars,
231                               OperandRange linearStepVars) {
232   size_t linearVarsSize = linearVars.size();
233   p << "linear(";
234   for (unsigned i = 0; i < linearVarsSize; ++i) {
235     std::string separator = i == linearVarsSize - 1 ? ") " : ", ";
236     p << linearVars[i];
237     if (linearStepVars.size() > i)
238       p << " = " << linearStepVars[i];
239     p << " : " << linearVars[i].getType() << separator;
240   }
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // Parser and printer for Schedule Clause
245 //===----------------------------------------------------------------------===//
246 
247 static ParseResult
248 verifyScheduleModifiers(OpAsmParser &parser,
249                         SmallVectorImpl<SmallString<12>> &modifiers) {
250   if (modifiers.size() > 2)
251     return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
252   for (const auto &mod : modifiers) {
253     // Translate the string. If it has no value, then it was not a valid
254     // modifier!
255     auto symbol = symbolizeScheduleModifier(mod);
256     if (!symbol.hasValue())
257       return parser.emitError(parser.getNameLoc())
258              << " unknown modifier type: " << mod;
259   }
260 
261   // If we have one modifier that is "simd", then stick a "none" modiifer in
262   // index 0.
263   if (modifiers.size() == 1) {
264     if (symbolizeScheduleModifier(modifiers[0]) ==
265         mlir::omp::ScheduleModifier::simd) {
266       modifiers.push_back(modifiers[0]);
267       modifiers[0] =
268           stringifyScheduleModifier(mlir::omp::ScheduleModifier::none);
269     }
270   } else if (modifiers.size() == 2) {
271     // If there are two modifier:
272     // First modifier should not be simd, second one should be simd
273     if (symbolizeScheduleModifier(modifiers[0]) ==
274             mlir::omp::ScheduleModifier::simd ||
275         symbolizeScheduleModifier(modifiers[1]) !=
276             mlir::omp::ScheduleModifier::simd)
277       return parser.emitError(parser.getNameLoc())
278              << " incorrect modifier order";
279   }
280   return success();
281 }
282 
283 /// schedule ::= `schedule` `(` sched-list `)`
284 /// sched-list ::= sched-val | sched-val sched-list |
285 ///                sched-val `,` sched-modifier
286 /// sched-val ::= sched-with-chunk | sched-wo-chunk
287 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
288 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
289 /// sched-wo-chunk ::=  `auto` | `runtime`
290 /// sched-modifier ::=  sched-mod-val | sched-mod-val `,` sched-mod-val
291 /// sched-mod-val ::=  `monotonic` | `nonmonotonic` | `simd` | `none`
292 static ParseResult
293 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
294                     SmallVectorImpl<SmallString<12>> &modifiers,
295                     Optional<OpAsmParser::OperandType> &chunkSize) {
296   if (parser.parseLParen())
297     return failure();
298 
299   StringRef keyword;
300   if (parser.parseKeyword(&keyword))
301     return failure();
302 
303   schedule = keyword;
304   if (keyword == "static" || keyword == "dynamic" || keyword == "guided") {
305     if (succeeded(parser.parseOptionalEqual())) {
306       chunkSize = OpAsmParser::OperandType{};
307       if (parser.parseOperand(*chunkSize))
308         return failure();
309     } else {
310       chunkSize = llvm::NoneType::None;
311     }
312   } else if (keyword == "auto" || keyword == "runtime") {
313     chunkSize = llvm::NoneType::None;
314   } else {
315     return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
316   }
317 
318   // If there is a comma, we have one or more modifiers..
319   while (succeeded(parser.parseOptionalComma())) {
320     StringRef mod;
321     if (parser.parseKeyword(&mod))
322       return failure();
323     modifiers.push_back(mod);
324   }
325 
326   if (parser.parseRParen())
327     return failure();
328 
329   if (verifyScheduleModifiers(parser, modifiers))
330     return failure();
331 
332   return success();
333 }
334 
335 /// Print schedule clause
336 static void printScheduleClause(OpAsmPrinter &p, StringRef &sched,
337                                 llvm::Optional<StringRef> modifier, bool simd,
338                                 Value scheduleChunkVar) {
339   std::string schedLower = sched.lower();
340   p << "schedule(" << schedLower;
341   if (scheduleChunkVar)
342     p << " = " << scheduleChunkVar;
343   if (modifier && modifier.hasValue())
344     p << ", " << modifier;
345   if (simd)
346     p << ", simd";
347   p << ") ";
348 }
349 
350 //===----------------------------------------------------------------------===//
351 // Parser, printer and verifier for ReductionVarList
352 //===----------------------------------------------------------------------===//
353 
354 /// reduction ::= `reduction` `(` reduction-entry-list `)`
355 /// reduction-entry-list ::= reduction-entry
356 ///                        | reduction-entry-list `,` reduction-entry
357 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
358 static ParseResult
359 parseReductionVarList(OpAsmParser &parser,
360                       SmallVectorImpl<SymbolRefAttr> &symbols,
361                       SmallVectorImpl<OpAsmParser::OperandType> &operands,
362                       SmallVectorImpl<Type> &types) {
363   if (failed(parser.parseLParen()))
364     return failure();
365 
366   do {
367     if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() ||
368         parser.parseOperand(operands.emplace_back()) ||
369         parser.parseColonType(types.emplace_back()))
370       return failure();
371   } while (succeeded(parser.parseOptionalComma()));
372   return parser.parseRParen();
373 }
374 
375 /// Print Reduction clause
376 static void printReductionVarList(OpAsmPrinter &p,
377                                   Optional<ArrayAttr> reductions,
378                                   OperandRange reductionVars) {
379   p << "reduction(";
380   for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
381     if (i != 0)
382       p << ", ";
383     p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
384       << reductionVars[i].getType();
385   }
386   p << ") ";
387 }
388 
389 /// Verifies Reduction Clause
390 static LogicalResult verifyReductionVarList(Operation *op,
391                                             Optional<ArrayAttr> reductions,
392                                             OperandRange reductionVars) {
393   if (!reductionVars.empty()) {
394     if (!reductions || reductions->size() != reductionVars.size())
395       return op->emitOpError()
396              << "expected as many reduction symbol references "
397                 "as reduction variables";
398   } else {
399     if (reductions)
400       return op->emitOpError() << "unexpected reduction symbol references";
401     return success();
402   }
403 
404   DenseSet<Value> accumulators;
405   for (auto args : llvm::zip(reductionVars, *reductions)) {
406     Value accum = std::get<0>(args);
407 
408     if (!accumulators.insert(accum).second)
409       return op->emitOpError() << "accumulator variable used more than once";
410 
411     Type varType = accum.getType().cast<PointerLikeType>();
412     auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
413     auto decl =
414         SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
415     if (!decl)
416       return op->emitOpError() << "expected symbol reference " << symbolRef
417                                << " to point to a reduction declaration";
418 
419     if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
420       return op->emitOpError()
421              << "expected accumulator (" << varType
422              << ") to be the same type as reduction declaration ("
423              << decl.getAccumulatorType() << ")";
424   }
425 
426   return success();
427 }
428 
429 //===----------------------------------------------------------------------===//
430 // Parser, printer and verifier for Synchronization Hint (2.17.12)
431 //===----------------------------------------------------------------------===//
432 
433 /// Parses a Synchronization Hint clause. The value of hint is an integer
434 /// which is a combination of different hints from `omp_sync_hint_t`.
435 ///
436 /// hint-clause = `hint` `(` hint-value `)`
437 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
438                                             IntegerAttr &hintAttr,
439                                             bool parseKeyword = true) {
440   if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) {
441     hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
442     return success();
443   }
444 
445   if (failed(parser.parseLParen()))
446     return failure();
447   StringRef hintKeyword;
448   int64_t hint = 0;
449   do {
450     if (failed(parser.parseKeyword(&hintKeyword)))
451       return failure();
452     if (hintKeyword == "uncontended")
453       hint |= 1;
454     else if (hintKeyword == "contended")
455       hint |= 2;
456     else if (hintKeyword == "nonspeculative")
457       hint |= 4;
458     else if (hintKeyword == "speculative")
459       hint |= 8;
460     else
461       return parser.emitError(parser.getCurrentLocation())
462              << hintKeyword << " is not a valid hint";
463   } while (succeeded(parser.parseOptionalComma()));
464   if (failed(parser.parseRParen()))
465     return failure();
466   hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
467   return success();
468 }
469 
470 /// Prints a Synchronization Hint clause
471 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
472                                      IntegerAttr hintAttr) {
473   int64_t hint = hintAttr.getInt();
474 
475   if (hint == 0)
476     return;
477 
478   // Helper function to get n-th bit from the right end of `value`
479   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
480 
481   bool uncontended = bitn(hint, 0);
482   bool contended = bitn(hint, 1);
483   bool nonspeculative = bitn(hint, 2);
484   bool speculative = bitn(hint, 3);
485 
486   SmallVector<StringRef> hints;
487   if (uncontended)
488     hints.push_back("uncontended");
489   if (contended)
490     hints.push_back("contended");
491   if (nonspeculative)
492     hints.push_back("nonspeculative");
493   if (speculative)
494     hints.push_back("speculative");
495 
496   p << "hint(";
497   llvm::interleaveComma(hints, p);
498   p << ") ";
499 }
500 
501 /// Verifies a synchronization hint clause
502 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
503 
504   // Helper function to get n-th bit from the right end of `value`
505   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
506 
507   bool uncontended = bitn(hint, 0);
508   bool contended = bitn(hint, 1);
509   bool nonspeculative = bitn(hint, 2);
510   bool speculative = bitn(hint, 3);
511 
512   if (uncontended && contended)
513     return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
514                                 "omp_sync_hint_contended cannot be combined";
515   if (nonspeculative && speculative)
516     return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
517                                 "omp_sync_hint_speculative cannot be combined.";
518   return success();
519 }
520 
/// Identifies the kinds of clauses handled by the clause-list parser in this
/// file. The enumerators are used both unqualified (e.g. `ifClause`) and as
/// indices into per-clause tables, so their values must stay dense and start
/// at zero.
enum ClauseType {
  ifClause,
  numThreadsClause,
  privateClause,
  firstprivateClause,
  lastprivateClause,
  sharedClause,
  copyinClause,
  allocateClause,
  defaultClause,
  procBindClause,
  reductionClause,
  nowaitClause,
  linearClause,
  scheduleClause,
  collapseClause,
  orderClause,
  orderedClause,
  memoryOrderClause,
  hintClause,
  // Not a clause: the number of enumerators above, used to size the
  // per-clause tables. Must remain the last enumerator.
  COUNT
};
543 
544 //===----------------------------------------------------------------------===//
545 // Parser for Clause List
546 //===----------------------------------------------------------------------===//
547 
548 /// Parse a list of clauses. The clauses can appear in any order, but their
549 /// operand segment indices are in the same order that they are passed in the
550 /// `clauses` list. The operand segments are added over the prevSegments
551 
552 /// clause-list ::= clause clause-list | empty
553 /// clause ::= if | num-threads | private | firstprivate | lastprivate |
554 ///            shared | copyin | allocate | default | proc-bind | reduction |
555 ///            nowait | linear | schedule | collapse | order | ordered |
556 ///            inclusive
557 /// if ::= `if` `(` ssa-id-and-type `)`
558 /// num-threads ::= `num_threads` `(` ssa-id-and-type `)`
559 /// private ::= `private` operand-and-type-list
560 /// firstprivate ::= `firstprivate` operand-and-type-list
561 /// lastprivate ::= `lastprivate` operand-and-type-list
562 /// shared ::= `shared` operand-and-type-list
563 /// copyin ::= `copyin` operand-and-type-list
564 /// allocate ::= `allocate` `(` allocate-operand-list `)`
/// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`)
///             `)`
566 /// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
567 /// reduction ::= `reduction` `(` reduction-entry-list `)`
568 /// nowait ::= `nowait`
569 /// linear ::= `linear` `(` linear-list `)`
570 /// schedule ::= `schedule` `(` sched-list `)`
571 /// collapse ::= `collapse` `(` ssa-id-and-type `)`
572 /// order ::= `order` `(` `concurrent` `)`
573 /// ordered ::= `ordered` `(` ssa-id-and-type `)`
574 /// inclusive ::= `inclusive`
575 ///
/// Note that each clause can only appear once in the clause-list.
577 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
578                                 SmallVectorImpl<ClauseType> &clauses,
579                                 SmallVectorImpl<int> &segments) {
580 
581   // Check done[clause] to see if it has been parsed already
582   llvm::BitVector done(ClauseType::COUNT, false);
583 
584   // See pos[clause] to get position of clause in operand segments
585   SmallVector<int> pos(ClauseType::COUNT, -1);
586 
587   // Stores the last parsed clause keyword
588   StringRef clauseKeyword;
589   StringRef opName = result.name.getStringRef();
590 
591   // Containers for storing operands, types and attributes for various clauses
592   std::pair<OpAsmParser::OperandType, Type> ifCond;
593   std::pair<OpAsmParser::OperandType, Type> numThreads;
594 
595   SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates,
596       shareds, copyins;
597   SmallVector<Type> privateTypes, firstprivateTypes, lastprivateTypes,
598       sharedTypes, copyinTypes;
599 
600   SmallVector<OpAsmParser::OperandType> allocates, allocators;
601   SmallVector<Type> allocateTypes, allocatorTypes;
602 
603   SmallVector<SymbolRefAttr> reductionSymbols;
604   SmallVector<OpAsmParser::OperandType> reductionVars;
605   SmallVector<Type> reductionVarTypes;
606 
607   SmallVector<OpAsmParser::OperandType> linears;
608   SmallVector<Type> linearTypes;
609   SmallVector<OpAsmParser::OperandType> linearSteps;
610 
611   SmallString<8> schedule;
612   SmallVector<SmallString<12>> modifiers;
613   Optional<OpAsmParser::OperandType> scheduleChunkSize;
614 
615   // Compute the position of clauses in operand segments
616   int currPos = 0;
617   for (ClauseType clause : clauses) {
618 
619     // Skip the following clauses - they do not take any position in operand
620     // segments
621     if (clause == defaultClause || clause == procBindClause ||
622         clause == nowaitClause || clause == collapseClause ||
623         clause == orderClause || clause == orderedClause)
624       continue;
625 
626     pos[clause] = currPos++;
627 
628     // For the following clauses, two positions are reserved in the operand
629     // segments
630     if (clause == allocateClause || clause == linearClause)
631       currPos++;
632   }
633 
634   SmallVector<int> clauseSegments(currPos);
635 
636   // Helper function to check if a clause is allowed/repeated or not
637   auto checkAllowed = [&](ClauseType clause,
638                           bool allowRepeat = false) -> ParseResult {
639     if (!llvm::is_contained(clauses, clause))
640       return parser.emitError(parser.getCurrentLocation())
641              << clauseKeyword << " is not a valid clause for the " << opName
642              << " operation";
643     if (done[clause] && !allowRepeat)
644       return parser.emitError(parser.getCurrentLocation())
645              << "at most one " << clauseKeyword << " clause can appear on the "
646              << opName << " operation";
647     done[clause] = true;
648     return success();
649   };
650 
651   while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) {
652     if (clauseKeyword == "if") {
653       if (checkAllowed(ifClause) || parser.parseLParen() ||
654           parser.parseOperand(ifCond.first) ||
655           parser.parseColonType(ifCond.second) || parser.parseRParen())
656         return failure();
657       clauseSegments[pos[ifClause]] = 1;
658     } else if (clauseKeyword == "num_threads") {
659       if (checkAllowed(numThreadsClause) || parser.parseLParen() ||
660           parser.parseOperand(numThreads.first) ||
661           parser.parseColonType(numThreads.second) || parser.parseRParen())
662         return failure();
663       clauseSegments[pos[numThreadsClause]] = 1;
664     } else if (clauseKeyword == "private") {
665       if (checkAllowed(privateClause) ||
666           parseOperandAndTypeList(parser, privates, privateTypes))
667         return failure();
668       clauseSegments[pos[privateClause]] = privates.size();
669     } else if (clauseKeyword == "firstprivate") {
670       if (checkAllowed(firstprivateClause) ||
671           parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
672         return failure();
673       clauseSegments[pos[firstprivateClause]] = firstprivates.size();
674     } else if (clauseKeyword == "lastprivate") {
675       if (checkAllowed(lastprivateClause) ||
676           parseOperandAndTypeList(parser, lastprivates, lastprivateTypes))
677         return failure();
678       clauseSegments[pos[lastprivateClause]] = lastprivates.size();
679     } else if (clauseKeyword == "shared") {
680       if (checkAllowed(sharedClause) ||
681           parseOperandAndTypeList(parser, shareds, sharedTypes))
682         return failure();
683       clauseSegments[pos[sharedClause]] = shareds.size();
684     } else if (clauseKeyword == "copyin") {
685       if (checkAllowed(copyinClause) ||
686           parseOperandAndTypeList(parser, copyins, copyinTypes))
687         return failure();
688       clauseSegments[pos[copyinClause]] = copyins.size();
689     } else if (clauseKeyword == "allocate") {
690       if (checkAllowed(allocateClause) ||
691           parseAllocateAndAllocator(parser, allocates, allocateTypes,
692                                     allocators, allocatorTypes))
693         return failure();
694       clauseSegments[pos[allocateClause]] = allocates.size();
695       clauseSegments[pos[allocateClause] + 1] = allocators.size();
696     } else if (clauseKeyword == "default") {
697       StringRef defval;
698       if (checkAllowed(defaultClause) || parser.parseLParen() ||
699           parser.parseKeyword(&defval) || parser.parseRParen())
700         return failure();
701       // The def prefix is required for the attribute as "private" is a keyword
702       // in C++.
703       auto attr = parser.getBuilder().getStringAttr("def" + defval);
704       result.addAttribute("default_val", attr);
705     } else if (clauseKeyword == "proc_bind") {
706       StringRef bind;
707       if (checkAllowed(procBindClause) || parser.parseLParen() ||
708           parser.parseKeyword(&bind) || parser.parseRParen())
709         return failure();
710       auto attr = parser.getBuilder().getStringAttr(bind);
711       result.addAttribute("proc_bind_val", attr);
712     } else if (clauseKeyword == "reduction") {
713       if (checkAllowed(reductionClause) ||
714           parseReductionVarList(parser, reductionSymbols, reductionVars,
715                                 reductionVarTypes))
716         return failure();
717       clauseSegments[pos[reductionClause]] = reductionVars.size();
718     } else if (clauseKeyword == "nowait") {
719       if (checkAllowed(nowaitClause))
720         return failure();
721       auto attr = UnitAttr::get(parser.getBuilder().getContext());
722       result.addAttribute("nowait", attr);
723     } else if (clauseKeyword == "linear") {
724       if (checkAllowed(linearClause) ||
725           parseLinearClause(parser, linears, linearTypes, linearSteps))
726         return failure();
727       clauseSegments[pos[linearClause]] = linears.size();
728       clauseSegments[pos[linearClause] + 1] = linearSteps.size();
729     } else if (clauseKeyword == "schedule") {
730       if (checkAllowed(scheduleClause) ||
731           parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize))
732         return failure();
733       if (scheduleChunkSize) {
734         clauseSegments[pos[scheduleClause]] = 1;
735       }
736     } else if (clauseKeyword == "collapse") {
737       auto type = parser.getBuilder().getI64Type();
738       mlir::IntegerAttr attr;
739       if (checkAllowed(collapseClause) || parser.parseLParen() ||
740           parser.parseAttribute(attr, type) || parser.parseRParen())
741         return failure();
742       result.addAttribute("collapse_val", attr);
743     } else if (clauseKeyword == "ordered") {
744       mlir::IntegerAttr attr;
745       if (checkAllowed(orderedClause))
746         return failure();
747       if (succeeded(parser.parseOptionalLParen())) {
748         auto type = parser.getBuilder().getI64Type();
749         if (parser.parseAttribute(attr, type) || parser.parseRParen())
750           return failure();
751       } else {
752         // Use 0 to represent no ordered parameter was specified
753         attr = parser.getBuilder().getI64IntegerAttr(0);
754       }
755       result.addAttribute("ordered_val", attr);
756     } else if (clauseKeyword == "order") {
757       StringRef order;
758       if (checkAllowed(orderClause) || parser.parseLParen() ||
759           parser.parseKeyword(&order) || parser.parseRParen())
760         return failure();
761       auto attr = parser.getBuilder().getStringAttr(order);
762       result.addAttribute("order_val", attr);
763     } else if (clauseKeyword == "memory_order") {
764       StringRef memoryOrder;
765       if (checkAllowed(memoryOrderClause) || parser.parseLParen() ||
766           parser.parseKeyword(&memoryOrder) || parser.parseRParen())
767         return failure();
768       result.addAttribute("memory_order",
769                           parser.getBuilder().getStringAttr(memoryOrder));
770     } else if (clauseKeyword == "hint") {
771       IntegerAttr hint;
772       if (checkAllowed(hintClause) ||
773           parseSynchronizationHint(parser, hint, false))
774         return failure();
775       result.addAttribute("hint", hint);
776     } else {
777       return parser.emitError(parser.getNameLoc())
778              << clauseKeyword << " is not a valid clause";
779     }
780   }
781 
782   // Add if parameter.
783   if (done[ifClause] && clauseSegments[pos[ifClause]] &&
784       failed(
785           parser.resolveOperand(ifCond.first, ifCond.second, result.operands)))
786     return failure();
787 
788   // Add num_threads parameter.
789   if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] &&
790       failed(parser.resolveOperand(numThreads.first, numThreads.second,
791                                    result.operands)))
792     return failure();
793 
794   // Add private parameters.
795   if (done[privateClause] && clauseSegments[pos[privateClause]] &&
796       failed(parser.resolveOperands(privates, privateTypes,
797                                     privates[0].location, result.operands)))
798     return failure();
799 
800   // Add firstprivate parameters.
801   if (done[firstprivateClause] && clauseSegments[pos[firstprivateClause]] &&
802       failed(parser.resolveOperands(firstprivates, firstprivateTypes,
803                                     firstprivates[0].location,
804                                     result.operands)))
805     return failure();
806 
807   // Add lastprivate parameters.
808   if (done[lastprivateClause] && clauseSegments[pos[lastprivateClause]] &&
809       failed(parser.resolveOperands(lastprivates, lastprivateTypes,
810                                     lastprivates[0].location, result.operands)))
811     return failure();
812 
813   // Add shared parameters.
814   if (done[sharedClause] && clauseSegments[pos[sharedClause]] &&
815       failed(parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
816                                     result.operands)))
817     return failure();
818 
819   // Add copyin parameters.
820   if (done[copyinClause] && clauseSegments[pos[copyinClause]] &&
821       failed(parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
822                                     result.operands)))
823     return failure();
824 
825   // Add allocate parameters.
826   if (done[allocateClause] && clauseSegments[pos[allocateClause]] &&
827       failed(parser.resolveOperands(allocates, allocateTypes,
828                                     allocates[0].location, result.operands)))
829     return failure();
830 
831   // Add allocator parameters.
832   if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] &&
833       failed(parser.resolveOperands(allocators, allocatorTypes,
834                                     allocators[0].location, result.operands)))
835     return failure();
836 
837   // Add reduction parameters and symbols
838   if (done[reductionClause] && clauseSegments[pos[reductionClause]]) {
839     if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
840                                       parser.getNameLoc(), result.operands)))
841       return failure();
842 
843     SmallVector<Attribute> reductions(reductionSymbols.begin(),
844                                       reductionSymbols.end());
845     result.addAttribute("reductions",
846                         parser.getBuilder().getArrayAttr(reductions));
847   }
848 
849   // Add linear parameters
850   if (done[linearClause] && clauseSegments[pos[linearClause]]) {
851     auto linearStepType = parser.getBuilder().getI32Type();
852     SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
853     if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location,
854                                       result.operands)) ||
855         failed(parser.resolveOperands(linearSteps, linearStepTypes,
856                                       linearSteps[0].location,
857                                       result.operands)))
858       return failure();
859   }
860 
861   // Add schedule parameters
862   if (done[scheduleClause] && !schedule.empty()) {
863     schedule[0] = llvm::toUpper(schedule[0]);
864     auto attr = parser.getBuilder().getStringAttr(schedule);
865     result.addAttribute("schedule_val", attr);
866     if (!modifiers.empty()) {
867       auto mod = parser.getBuilder().getStringAttr(modifiers[0]);
868       result.addAttribute("schedule_modifier", mod);
869       // Only SIMD attribute is allowed here!
870       if (modifiers.size() > 1) {
871         assert(symbolizeScheduleModifier(modifiers[1]) ==
872                mlir::omp::ScheduleModifier::simd);
873         auto attr = UnitAttr::get(parser.getBuilder().getContext());
874         result.addAttribute("simd_modifier", attr);
875       }
876     }
877     if (scheduleChunkSize) {
878       auto chunkSizeType = parser.getBuilder().getI32Type();
879       parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands);
880     }
881   }
882 
883   segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end());
884 
885   return success();
886 }
887 
888 /// Parses a parallel operation.
889 ///
890 /// operation ::= `omp.parallel` clause-list
891 /// clause-list ::= clause | clause clause-list
892 /// clause ::= if | num-threads | private | firstprivate | shared | copyin |
893 ///            allocate | default | proc-bind
894 ///
895 static ParseResult parseParallelOp(OpAsmParser &parser,
896                                    OperationState &result) {
897   SmallVector<ClauseType> clauses = {
898       ifClause,           numThreadsClause, privateClause,
899       firstprivateClause, sharedClause,     copyinClause,
900       allocateClause,     defaultClause,    procBindClause};
901 
902   SmallVector<int> segments;
903 
904   if (failed(parseClauses(parser, result, clauses, segments)))
905     return failure();
906 
907   result.addAttribute("operand_segment_sizes",
908                       parser.getBuilder().getI32VectorAttr(segments));
909 
910   Region *body = result.addRegion();
911   SmallVector<OpAsmParser::OperandType> regionArgs;
912   SmallVector<Type> regionArgTypes;
913   if (parser.parseRegion(*body, regionArgs, regionArgTypes))
914     return failure();
915   return success();
916 }
917 
918 //===----------------------------------------------------------------------===//
919 // Parser, printer and verifier for SectionsOp
920 //===----------------------------------------------------------------------===//
921 
922 /// Parses an OpenMP Sections operation
923 ///
924 /// sections ::= `omp.sections` clause-list
925 /// clause-list ::= clause clause-list | empty
926 /// clause ::= private | firstprivate | lastprivate | reduction | allocate |
927 ///            nowait
928 static ParseResult parseSectionsOp(OpAsmParser &parser,
929                                    OperationState &result) {
930 
931   SmallVector<ClauseType> clauses = {privateClause,     firstprivateClause,
932                                      lastprivateClause, reductionClause,
933                                      allocateClause,    nowaitClause};
934 
935   SmallVector<int> segments;
936 
937   if (failed(parseClauses(parser, result, clauses, segments)))
938     return failure();
939 
940   result.addAttribute("operand_segment_sizes",
941                       parser.getBuilder().getI32VectorAttr(segments));
942 
943   // Now parse the body.
944   Region *body = result.addRegion();
945   if (parser.parseRegion(*body))
946     return failure();
947   return success();
948 }
949 
950 static void printSectionsOp(OpAsmPrinter &p, SectionsOp op) {
951   p << " ";
952   printDataVars(p, op.private_vars(), "private");
953   printDataVars(p, op.firstprivate_vars(), "firstprivate");
954   printDataVars(p, op.lastprivate_vars(), "lastprivate");
955 
956   if (!op.reduction_vars().empty())
957     printReductionVarList(p, op.reductions(), op.reduction_vars());
958 
959   if (!op.allocate_vars().empty())
960     printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars());
961 
962   if (op.nowait())
963     p << "nowait ";
964 
965   p.printRegion(op.region());
966 }
967 
968 static LogicalResult verifySectionsOp(SectionsOp op) {
969 
970   // A list item may not appear in more than one clause on the same directive,
971   // except that it may be specified in both firstprivate and lastprivate
972   // clauses.
973   for (auto var : op.private_vars()) {
974     if (llvm::is_contained(op.firstprivate_vars(), var))
975       return op.emitOpError()
976              << "operand used in both private and firstprivate clauses";
977     if (llvm::is_contained(op.lastprivate_vars(), var))
978       return op.emitOpError()
979              << "operand used in both private and lastprivate clauses";
980   }
981 
982   if (op.allocate_vars().size() != op.allocators_vars().size())
983     return op.emitError(
984         "expected equal sizes for allocate and allocator variables");
985 
986   for (auto &inst : *op.region().begin()) {
987     if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst)))
988       op.emitOpError()
989           << "expected omp.section op or terminator op inside region";
990   }
991 
992   return verifyReductionVarList(op, op.reductions(), op.reduction_vars());
993 }
994 
995 /// Parses an OpenMP Workshare Loop operation
996 ///
997 /// wsloop ::= `omp.wsloop` loop-control clause-list
998 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
999 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
1000 /// steps := `step` `(`ssa-id-list`)`
1001 /// clause-list ::= clause clause-list | empty
1002 /// clause ::= private | firstprivate | lastprivate | linear | schedule |
1003 //             collapse | nowait | ordered | order | reduction
1004 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
1005 
1006   // Parse an opening `(` followed by induction variables followed by `)`
1007   SmallVector<OpAsmParser::OperandType> ivs;
1008   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
1009                                      OpAsmParser::Delimiter::Paren))
1010     return failure();
1011 
1012   int numIVs = static_cast<int>(ivs.size());
1013   Type loopVarType;
1014   if (parser.parseColonType(loopVarType))
1015     return failure();
1016 
1017   // Parse loop bounds.
1018   SmallVector<OpAsmParser::OperandType> lower;
1019   if (parser.parseEqual() ||
1020       parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
1021       parser.resolveOperands(lower, loopVarType, result.operands))
1022     return failure();
1023 
1024   SmallVector<OpAsmParser::OperandType> upper;
1025   if (parser.parseKeyword("to") ||
1026       parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
1027       parser.resolveOperands(upper, loopVarType, result.operands))
1028     return failure();
1029 
1030   if (succeeded(parser.parseOptionalKeyword("inclusive"))) {
1031     auto attr = UnitAttr::get(parser.getBuilder().getContext());
1032     result.addAttribute("inclusive", attr);
1033   }
1034 
1035   // Parse step values.
1036   SmallVector<OpAsmParser::OperandType> steps;
1037   if (parser.parseKeyword("step") ||
1038       parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
1039       parser.resolveOperands(steps, loopVarType, result.operands))
1040     return failure();
1041 
1042   SmallVector<ClauseType> clauses = {
1043       privateClause,   firstprivateClause, lastprivateClause, linearClause,
1044       reductionClause, collapseClause,     orderClause,       orderedClause,
1045       nowaitClause,    scheduleClause};
1046   SmallVector<int> segments{numIVs, numIVs, numIVs};
1047   if (failed(parseClauses(parser, result, clauses, segments)))
1048     return failure();
1049 
1050   result.addAttribute("operand_segment_sizes",
1051                       parser.getBuilder().getI32VectorAttr(segments));
1052 
1053   // Now parse the body.
1054   Region *body = result.addRegion();
1055   SmallVector<Type> ivTypes(numIVs, loopVarType);
1056   SmallVector<OpAsmParser::OperandType> blockArgs(ivs);
1057   if (parser.parseRegion(*body, blockArgs, ivTypes))
1058     return failure();
1059   return success();
1060 }
1061 
1062 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
1063   auto args = op.getRegion().front().getArguments();
1064   p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound()
1065     << ") to (" << op.upperBound() << ") ";
1066   if (op.inclusive()) {
1067     p << "inclusive ";
1068   }
1069   p << "step (" << op.step() << ") ";
1070 
1071   printDataVars(p, op.private_vars(), "private");
1072   printDataVars(p, op.firstprivate_vars(), "firstprivate");
1073   printDataVars(p, op.lastprivate_vars(), "lastprivate");
1074 
1075   if (!op.linear_vars().empty())
1076     printLinearClause(p, op.linear_vars(), op.linear_step_vars());
1077 
1078   if (auto sched = op.schedule_val())
1079     printScheduleClause(p, sched.getValue(), op.schedule_modifier(),
1080                         op.simd_modifier(), op.schedule_chunk_var());
1081 
1082   if (auto collapse = op.collapse_val())
1083     p << "collapse(" << collapse << ") ";
1084 
1085   if (op.nowait())
1086     p << "nowait ";
1087 
1088   if (auto ordered = op.ordered_val())
1089     p << "ordered(" << ordered << ") ";
1090 
1091   if (auto order = op.order_val())
1092     p << "order(" << order << ") ";
1093 
1094   if (!op.reduction_vars().empty())
1095     printReductionVarList(p, op.reductions(), op.reduction_vars());
1096 
1097   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
1098 }
1099 
1100 //===----------------------------------------------------------------------===//
1101 // ReductionOp
1102 //===----------------------------------------------------------------------===//
1103 
1104 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
1105                                               Region &region) {
1106   if (parser.parseOptionalKeyword("atomic"))
1107     return success();
1108   return parser.parseRegion(region);
1109 }
1110 
1111 static void printAtomicReductionRegion(OpAsmPrinter &printer,
1112                                        ReductionDeclareOp op, Region &region) {
1113   if (region.empty())
1114     return;
1115   printer << "atomic ";
1116   printer.printRegion(region);
1117 }
1118 
1119 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) {
1120   if (op.initializerRegion().empty())
1121     return op.emitOpError() << "expects non-empty initializer region";
1122   Block &initializerEntryBlock = op.initializerRegion().front();
1123   if (initializerEntryBlock.getNumArguments() != 1 ||
1124       initializerEntryBlock.getArgument(0).getType() != op.type()) {
1125     return op.emitOpError() << "expects initializer region with one argument "
1126                                "of the reduction type";
1127   }
1128 
1129   for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) {
1130     if (yieldOp.results().size() != 1 ||
1131         yieldOp.results().getTypes()[0] != op.type())
1132       return op.emitOpError() << "expects initializer region to yield a value "
1133                                  "of the reduction type";
1134   }
1135 
1136   if (op.reductionRegion().empty())
1137     return op.emitOpError() << "expects non-empty reduction region";
1138   Block &reductionEntryBlock = op.reductionRegion().front();
1139   if (reductionEntryBlock.getNumArguments() != 2 ||
1140       reductionEntryBlock.getArgumentTypes()[0] !=
1141           reductionEntryBlock.getArgumentTypes()[1] ||
1142       reductionEntryBlock.getArgumentTypes()[0] != op.type())
1143     return op.emitOpError() << "expects reduction region with two arguments of "
1144                                "the reduction type";
1145   for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) {
1146     if (yieldOp.results().size() != 1 ||
1147         yieldOp.results().getTypes()[0] != op.type())
1148       return op.emitOpError() << "expects reduction region to yield a value "
1149                                  "of the reduction type";
1150   }
1151 
1152   if (op.atomicReductionRegion().empty())
1153     return success();
1154 
1155   Block &atomicReductionEntryBlock = op.atomicReductionRegion().front();
1156   if (atomicReductionEntryBlock.getNumArguments() != 2 ||
1157       atomicReductionEntryBlock.getArgumentTypes()[0] !=
1158           atomicReductionEntryBlock.getArgumentTypes()[1])
1159     return op.emitOpError() << "expects atomic reduction region with two "
1160                                "arguments of the same type";
1161   auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
1162                      .dyn_cast<PointerLikeType>();
1163   if (!ptrType || ptrType.getElementType() != op.type())
1164     return op.emitOpError() << "expects atomic reduction region arguments to "
1165                                "be accumulators containing the reduction type";
1166   return success();
1167 }
1168 
1169 static LogicalResult verifyReductionOp(ReductionOp op) {
1170   // TODO: generalize this to an op interface when there is more than one op
1171   // that supports reductions.
1172   auto container = op->getParentOfType<WsLoopOp>();
1173   for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
1174     if (container.reduction_vars()[i] == op.accumulator())
1175       return success();
1176 
1177   return op.emitOpError() << "the accumulator is not used by the parent";
1178 }
1179 
1180 //===----------------------------------------------------------------------===//
1181 // WsLoopOp
1182 //===----------------------------------------------------------------------===//
1183 
1184 void WsLoopOp::build(OpBuilder &builder, OperationState &state,
1185                      ValueRange lowerBound, ValueRange upperBound,
1186                      ValueRange step, ArrayRef<NamedAttribute> attributes) {
1187   build(builder, state, TypeRange(), lowerBound, upperBound, step,
1188         /*privateVars=*/ValueRange(),
1189         /*firstprivateVars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
1190         /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
1191         /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr,
1192         /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr,
1193         /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr,
1194         /*inclusive=*/nullptr, /*buildBody=*/false);
1195   state.addAttributes(attributes);
1196 }
1197 
1198 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes,
1199                      ValueRange operands, ArrayRef<NamedAttribute> attributes) {
1200   state.addOperands(operands);
1201   state.addAttributes(attributes);
1202   (void)state.addRegion();
1203   assert(resultTypes.empty() && "mismatched number of return types");
1204   state.addTypes(resultTypes);
1205 }
1206 
1207 void WsLoopOp::build(OpBuilder &builder, OperationState &result,
1208                      TypeRange typeRange, ValueRange lowerBounds,
1209                      ValueRange upperBounds, ValueRange steps,
1210                      ValueRange privateVars, ValueRange firstprivateVars,
1211                      ValueRange lastprivateVars, ValueRange linearVars,
1212                      ValueRange linearStepVars, ValueRange reductionVars,
1213                      StringAttr scheduleVal, Value scheduleChunkVar,
1214                      IntegerAttr collapseVal, UnitAttr nowait,
1215                      IntegerAttr orderedVal, StringAttr orderVal,
1216                      UnitAttr inclusive, bool buildBody) {
1217   result.addOperands(lowerBounds);
1218   result.addOperands(upperBounds);
1219   result.addOperands(steps);
1220   result.addOperands(privateVars);
1221   result.addOperands(firstprivateVars);
1222   result.addOperands(linearVars);
1223   result.addOperands(linearStepVars);
1224   if (scheduleChunkVar)
1225     result.addOperands(scheduleChunkVar);
1226 
1227   if (scheduleVal)
1228     result.addAttribute("schedule_val", scheduleVal);
1229   if (collapseVal)
1230     result.addAttribute("collapse_val", collapseVal);
1231   if (nowait)
1232     result.addAttribute("nowait", nowait);
1233   if (orderedVal)
1234     result.addAttribute("ordered_val", orderedVal);
1235   if (orderVal)
1236     result.addAttribute("order", orderVal);
1237   if (inclusive)
1238     result.addAttribute("inclusive", inclusive);
1239   result.addAttribute(
1240       WsLoopOp::getOperandSegmentSizeAttr(),
1241       builder.getI32VectorAttr(
1242           {static_cast<int32_t>(lowerBounds.size()),
1243            static_cast<int32_t>(upperBounds.size()),
1244            static_cast<int32_t>(steps.size()),
1245            static_cast<int32_t>(privateVars.size()),
1246            static_cast<int32_t>(firstprivateVars.size()),
1247            static_cast<int32_t>(lastprivateVars.size()),
1248            static_cast<int32_t>(linearVars.size()),
1249            static_cast<int32_t>(linearStepVars.size()),
1250            static_cast<int32_t>(reductionVars.size()),
1251            static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)}));
1252 
1253   Region *bodyRegion = result.addRegion();
1254   if (buildBody) {
1255     OpBuilder::InsertionGuard guard(builder);
1256     unsigned numIVs = steps.size();
1257     SmallVector<Type, 8> argTypes(numIVs, steps.getType().front());
1258     builder.createBlock(bodyRegion, {}, argTypes);
1259   }
1260 }
1261 
/// Verifier for WsLoopOp: the only check performed is on the reduction
/// clause, by forwarding the op's `reductions()` and `reduction_vars()` to
/// `verifyReductionVarList`.
static LogicalResult verifyWsLoopOp(WsLoopOp op) {
  return verifyReductionVarList(op, op.reductions(), op.reduction_vars());
}
1265 
1266 //===----------------------------------------------------------------------===//
1267 // Verifier for critical construct (2.17.1)
1268 //===----------------------------------------------------------------------===//
1269 
/// Verifier for CriticalDeclareOp: the only check performed is
/// `verifySynchronizationHint` on the op's `hint()` value.
static LogicalResult verifyCriticalDeclareOp(CriticalDeclareOp op) {
  return verifySynchronizationHint(op, op.hint());
}
1273 
1274 static LogicalResult verifyCriticalOp(CriticalOp op) {
1275 
1276   if (op.nameAttr()) {
1277     auto symbolRef = op.nameAttr().cast<SymbolRefAttr>();
1278     auto decl =
1279         SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef);
1280     if (!decl) {
1281       return op.emitOpError() << "expected symbol reference " << symbolRef
1282                               << " to point to a critical declaration";
1283     }
1284   }
1285 
1286   return success();
1287 }
1288 
1289 //===----------------------------------------------------------------------===//
1290 // Verifier for ordered construct
1291 //===----------------------------------------------------------------------===//
1292 
1293 static LogicalResult verifyOrderedOp(OrderedOp op) {
1294   auto container = op->getParentOfType<WsLoopOp>();
1295   if (!container || !container.ordered_valAttr() ||
1296       container.ordered_valAttr().getInt() == 0)
1297     return op.emitOpError() << "ordered depend directive must be closely "
1298                             << "nested inside a worksharing-loop with ordered "
1299                             << "clause with parameter present";
1300 
1301   if (container.ordered_valAttr().getInt() !=
1302       (int64_t)op.num_loops_val().getValue())
1303     return op.emitOpError() << "number of variables in depend clause does not "
1304                             << "match number of iteration variables in the "
1305                             << "doacross loop";
1306 
1307   return success();
1308 }
1309 
1310 static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) {
1311   // TODO: The code generation for ordered simd directive is not supported yet.
1312   if (op.simd())
1313     return failure();
1314 
1315   if (auto container = op->getParentOfType<WsLoopOp>()) {
1316     if (!container.ordered_valAttr() ||
1317         container.ordered_valAttr().getInt() != 0)
1318       return op.emitOpError() << "ordered region must be closely nested inside "
1319                               << "a worksharing-loop region with an ordered "
1320                               << "clause without parameter present";
1321   }
1322 
1323   return success();
1324 }
1325 
1326 //===----------------------------------------------------------------------===//
1327 // AtomicReadOp
1328 //===----------------------------------------------------------------------===//
1329 
1330 /// Parser for AtomicReadOp
1331 ///
1332 /// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type
1333 /// address ::= operand `:` type
1334 static ParseResult parseAtomicReadOp(OpAsmParser &parser,
1335                                      OperationState &result) {
1336   OpAsmParser::OperandType x, v;
1337   Type addressType;
1338   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1339   SmallVector<int> segments;
1340 
1341   if (parser.parseOperand(v) || parser.parseEqual() || parser.parseOperand(x) ||
1342       parseClauses(parser, result, clauses, segments) ||
1343       parser.parseColonType(addressType) ||
1344       parser.resolveOperand(x, addressType, result.operands) ||
1345       parser.resolveOperand(v, addressType, result.operands))
1346     return failure();
1347 
1348   return success();
1349 }
1350 
1351 /// Printer for AtomicReadOp
1352 static void printAtomicReadOp(OpAsmPrinter &p, AtomicReadOp op) {
1353   p << " " << op.v() << " = " << op.x() << " ";
1354   if (op.memory_order())
1355     p << "memory_order(" << op.memory_order().getValue() << ") ";
1356   if (op.hintAttr())
1357     printSynchronizationHint(p << " ", op, op.hintAttr());
1358   p << ": " << op.x().getType();
1359 }
1360 
1361 /// Verifier for AtomicReadOp
1362 static LogicalResult verifyAtomicReadOp(AtomicReadOp op) {
1363   if (op.memory_order()) {
1364     StringRef memOrder = op.memory_order().getValue();
1365     if (memOrder.equals("acq_rel") || memOrder.equals("release"))
1366       return op.emitError(
1367           "memory-order must not be acq_rel or release for atomic reads");
1368   }
1369   if (op.x() == op.v())
1370     return op.emitError(
1371         "read and write must not be to the same location for atomic reads");
1372   return verifySynchronizationHint(op, op.hint());
1373 }
1374 
1375 //===----------------------------------------------------------------------===//
1376 // AtomicWriteOp
1377 //===----------------------------------------------------------------------===//
1378 
1379 /// Parser for AtomicWriteOp
1380 ///
1381 /// operation ::= `omp.atomic.write` atomic-clause-list operands
1382 /// operands ::= address `,` value
1383 /// address ::= operand `:` type
1384 /// value ::= operand `:` type
1385 static ParseResult parseAtomicWriteOp(OpAsmParser &parser,
1386                                       OperationState &result) {
1387   OpAsmParser::OperandType address, value;
1388   Type addrType, valueType;
1389   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1390   SmallVector<int> segments;
1391 
1392   if (parser.parseOperand(address) || parser.parseEqual() ||
1393       parser.parseOperand(value) ||
1394       parseClauses(parser, result, clauses, segments) ||
1395       parser.parseColonType(addrType) || parser.parseComma() ||
1396       parser.parseType(valueType) ||
1397       parser.resolveOperand(address, addrType, result.operands) ||
1398       parser.resolveOperand(value, valueType, result.operands))
1399     return failure();
1400   return success();
1401 }
1402 
1403 /// Printer for AtomicWriteOp
1404 static void printAtomicWriteOp(OpAsmPrinter &p, AtomicWriteOp op) {
1405   p << " " << op.address() << " = " << op.value() << " ";
1406   if (op.memory_order())
1407     p << "memory_order(" << op.memory_order() << ") ";
1408   if (op.hintAttr())
1409     printSynchronizationHint(p, op, op.hintAttr());
1410   p << ": " << op.address().getType() << ", " << op.value().getType();
1411 }
1412 
1413 /// Verifier for AtomicWriteOp
1414 static LogicalResult verifyAtomicWriteOp(AtomicWriteOp op) {
1415   if (op.memory_order()) {
1416     StringRef memoryOrder = op.memory_order().getValue();
1417     if (memoryOrder.equals("acq_rel") || memoryOrder.equals("acquire"))
1418       return op.emitError(
1419           "memory-order must not be acq_rel or acquire for atomic writes");
1420   }
1421   return verifySynchronizationHint(op, op.hint());
1422 }
1423 
1424 //===----------------------------------------------------------------------===//
1425 // AtomicUpdateOp
1426 //===----------------------------------------------------------------------===//
1427 
1428 /// Parser for AtomicUpdateOp
1429 ///
1430 /// operation ::= `omp.atomic.update` atomic-clause-list region
1431 static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
1432                                        OperationState &result) {
1433   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1434   SmallVector<int> segments;
1435   OpAsmParser::OperandType x, y, z;
1436   Type xType, exprType;
1437   StringRef binOp;
1438 
1439   // x = y `op` z : xtype, exprtype
1440   if (parser.parseOperand(x) || parser.parseEqual() || parser.parseOperand(y) ||
1441       parser.parseKeyword(&binOp) || parser.parseOperand(z) ||
1442       parseClauses(parser, result, clauses, segments) || parser.parseColon() ||
1443       parser.parseType(xType) || parser.parseComma() ||
1444       parser.parseType(exprType) ||
1445       parser.resolveOperand(x, xType, result.operands)) {
1446     return failure();
1447   }
1448 
1449   auto binOpEnum = AtomicBinOpKindToEnum(binOp.upper());
1450   if (!binOpEnum)
1451     return parser.emitError(parser.getNameLoc())
1452            << "invalid atomic bin op in atomic update\n";
1453   auto attr =
1454       parser.getBuilder().getI64IntegerAttr((int64_t)binOpEnum.getValue());
1455   result.addAttribute("binop", attr);
1456 
1457   OpAsmParser::OperandType expr;
1458   if (x.name == y.name && x.number == y.number) {
1459     expr = z;
1460     result.addAttribute("isXBinopExpr", parser.getBuilder().getUnitAttr());
1461   } else if (x.name == z.name && x.number == z.number) {
1462     expr = y;
1463   } else {
1464     return parser.emitError(parser.getNameLoc())
1465            << "atomic update variable " << x.name
1466            << " not found in the RHS of the assignment statement in an"
1467               " atomic.update operation";
1468   }
1469   return parser.resolveOperand(expr, exprType, result.operands);
1470 }
1471 
1472 /// Printer for AtomicUpdateOp
1473 static void printAtomicUpdateOp(OpAsmPrinter &p, AtomicUpdateOp op) {
1474   p << " " << op.x() << " = ";
1475   Value y, z;
1476   if (op.isXBinopExpr()) {
1477     y = op.x();
1478     z = op.expr();
1479   } else {
1480     y = op.expr();
1481     z = op.x();
1482   }
1483   p << y << " " << AtomicBinOpKindToString(op.binop()).lower() << " " << z
1484     << " ";
1485   if (op.memory_order())
1486     p << "memory_order(" << op.memory_order() << ") ";
1487   if (op.hintAttr())
1488     printSynchronizationHint(p, op, op.hintAttr());
1489   p << ": " << op.x().getType() << ", " << op.expr().getType();
1490 }
1491 
1492 /// Verifier for AtomicUpdateOp
1493 static LogicalResult verifyAtomicUpdateOp(AtomicUpdateOp op) {
1494   if (op.memory_order()) {
1495     StringRef memoryOrder = op.memory_order().getValue();
1496     if (memoryOrder.equals("acq_rel") || memoryOrder.equals("acquire"))
1497       return op.emitError(
1498           "memory-order must not be acq_rel or acquire for atomic updates");
1499   }
1500   return success();
1501 }
1502 
1503 #define GET_OP_CLASSES
1504 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
1505