1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/OperationSupport.h"
19 
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/SmallString.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/ADT/StringSwitch.h"
25 #include <cstddef>
26 
27 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
28 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
29 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
30 
31 using namespace mlir;
32 using namespace mlir::omp;
33 
34 namespace {
35 /// Model for pointer-like types that already provide a `getElementType` method.
36 template <typename T>
37 struct PointerLikeModel
38     : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
39   Type getElementType(Type pointer) const {
40     return pointer.cast<T>().getElementType();
41   }
42 };
43 } // namespace
44 
/// Registers every operation generated for the OpenMP dialect and attaches the
/// `PointerLikeType` external model to the LLVM dialect pointer type and the
/// builtin memref type, so values of those types can be treated as
/// `PointerLikeType` (e.g. by the reduction verifier in this file).
void OpenMPDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();

  // Both types already provide getElementType(), which PointerLikeModel
  // forwards to.
  LLVM::LLVMPointerType::attachInterface<
      PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
  MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
}
55 
56 //===----------------------------------------------------------------------===//
57 // ParallelOp
58 //===----------------------------------------------------------------------===//
59 
60 void ParallelOp::build(OpBuilder &builder, OperationState &state,
61                        ArrayRef<NamedAttribute> attributes) {
62   ParallelOp::build(
63       builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
64       /*default_val=*/nullptr, /*private_vars=*/ValueRange(),
65       /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(),
66       /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(),
67       /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr);
68   state.addAttributes(attributes);
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // Parser and printer for Operand and type list
73 //===----------------------------------------------------------------------===//
74 
75 /// Parse a list of operands with types.
76 ///
77 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
78 /// ssa-id-and-type-list ::= ssa-id-and-type |
79 ///                          ssa-id-and-type `,` ssa-id-and-type-list
80 /// ssa-id-and-type ::= ssa-id `:` type
81 static ParseResult
82 parseOperandAndTypeList(OpAsmParser &parser,
83                         SmallVectorImpl<OpAsmParser::OperandType> &operands,
84                         SmallVectorImpl<Type> &types) {
85   return parser.parseCommaSeparatedList(
86       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
87         OpAsmParser::OperandType operand;
88         Type type;
89         if (parser.parseOperand(operand) || parser.parseColonType(type))
90           return failure();
91         operands.push_back(operand);
92         types.push_back(type);
93         return success();
94       });
95 }
96 
97 /// Print an operand and type list with parentheses
98 static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) {
99   p << "(";
100   llvm::interleaveComma(
101       operands, p, [&](const Value &v) { p << v << " : " << v.getType(); });
102   p << ") ";
103 }
104 
105 /// Print data variables corresponding to a data-sharing clause `name`
106 static void printDataVars(OpAsmPrinter &p, OperandRange operands,
107                           StringRef name) {
108   if (!operands.empty()) {
109     p << name;
110     printOperandAndTypeList(p, operands);
111   }
112 }
113 
114 //===----------------------------------------------------------------------===//
115 // Parser and printer for Allocate Clause
116 //===----------------------------------------------------------------------===//
117 
118 /// Parse an allocate clause with allocators and a list of operands with types.
119 ///
120 /// allocate ::= `allocate` `(` allocate-operand-list `)`
121 /// allocate-operand-list :: = allocate-operand |
122 ///                            allocator-operand `,` allocate-operand-list
123 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
124 /// ssa-id-and-type ::= ssa-id `:` type
125 static ParseResult parseAllocateAndAllocator(
126     OpAsmParser &parser,
127     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
128     SmallVectorImpl<Type> &typesAllocate,
129     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
130     SmallVectorImpl<Type> &typesAllocator) {
131 
132   return parser.parseCommaSeparatedList(
133       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
134         OpAsmParser::OperandType operand;
135         Type type;
136         if (parser.parseOperand(operand) || parser.parseColonType(type))
137           return failure();
138         operandsAllocator.push_back(operand);
139         typesAllocator.push_back(type);
140         if (parser.parseArrow())
141           return failure();
142         if (parser.parseOperand(operand) || parser.parseColonType(type))
143           return failure();
144 
145         operandsAllocate.push_back(operand);
146         typesAllocate.push_back(type);
147         return success();
148       });
149 }
150 
151 /// Print allocate clause
152 static void printAllocateAndAllocator(OpAsmPrinter &p,
153                                       OperandRange varsAllocate,
154                                       OperandRange varsAllocator) {
155   p << "allocate(";
156   for (unsigned i = 0; i < varsAllocate.size(); ++i) {
157     std::string separator = i == varsAllocate.size() - 1 ? ") " : ", ";
158     p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
159     p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
160   }
161 }
162 
163 static LogicalResult verifyParallelOp(ParallelOp op) {
164   if (op.allocate_vars().size() != op.allocators_vars().size())
165     return op.emitError(
166         "expected equal sizes for allocate and allocator variables");
167   return success();
168 }
169 
170 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
171   p << " ";
172   if (auto ifCond = op.if_expr_var())
173     p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
174 
175   if (auto threads = op.num_threads_var())
176     p << "num_threads(" << threads << " : " << threads.getType() << ") ";
177 
178   printDataVars(p, op.private_vars(), "private");
179   printDataVars(p, op.firstprivate_vars(), "firstprivate");
180   printDataVars(p, op.shared_vars(), "shared");
181   printDataVars(p, op.copyin_vars(), "copyin");
182 
183   if (!op.allocate_vars().empty())
184     printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars());
185 
186   if (auto def = op.default_val())
187     p << "default(" << def->drop_front(3) << ") ";
188 
189   if (auto bind = op.proc_bind_val())
190     p << "proc_bind(" << bind << ") ";
191 
192   p << ' ';
193   p.printRegion(op.getRegion());
194 }
195 
196 //===----------------------------------------------------------------------===//
197 // Parser and printer for Linear Clause
198 //===----------------------------------------------------------------------===//
199 
200 /// linear ::= `linear` `(` linear-list `)`
201 /// linear-list := linear-val | linear-val linear-list
202 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
203 static ParseResult
204 parseLinearClause(OpAsmParser &parser,
205                   SmallVectorImpl<OpAsmParser::OperandType> &vars,
206                   SmallVectorImpl<Type> &types,
207                   SmallVectorImpl<OpAsmParser::OperandType> &stepVars) {
208   if (parser.parseLParen())
209     return failure();
210 
211   do {
212     OpAsmParser::OperandType var;
213     Type type;
214     OpAsmParser::OperandType stepVar;
215     if (parser.parseOperand(var) || parser.parseEqual() ||
216         parser.parseOperand(stepVar) || parser.parseColonType(type))
217       return failure();
218 
219     vars.push_back(var);
220     types.push_back(type);
221     stepVars.push_back(stepVar);
222   } while (succeeded(parser.parseOptionalComma()));
223 
224   if (parser.parseRParen())
225     return failure();
226 
227   return success();
228 }
229 
230 /// Print Linear Clause
231 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars,
232                               OperandRange linearStepVars) {
233   size_t linearVarsSize = linearVars.size();
234   p << "linear(";
235   for (unsigned i = 0; i < linearVarsSize; ++i) {
236     std::string separator = i == linearVarsSize - 1 ? ") " : ", ";
237     p << linearVars[i];
238     if (linearStepVars.size() > i)
239       p << " = " << linearStepVars[i];
240     p << " : " << linearVars[i].getType() << separator;
241   }
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // Parser and printer for Schedule Clause
246 //===----------------------------------------------------------------------===//
247 
248 static ParseResult
249 verifyScheduleModifiers(OpAsmParser &parser,
250                         SmallVectorImpl<SmallString<12>> &modifiers) {
251   if (modifiers.size() > 2)
252     return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
253   for (const auto &mod : modifiers) {
254     // Translate the string. If it has no value, then it was not a valid
255     // modifier!
256     auto symbol = symbolizeScheduleModifier(mod);
257     if (!symbol.hasValue())
258       return parser.emitError(parser.getNameLoc())
259              << " unknown modifier type: " << mod;
260   }
261 
262   // If we have one modifier that is "simd", then stick a "none" modiifer in
263   // index 0.
264   if (modifiers.size() == 1) {
265     if (symbolizeScheduleModifier(modifiers[0]) ==
266         mlir::omp::ScheduleModifier::simd) {
267       modifiers.push_back(modifiers[0]);
268       modifiers[0] =
269           stringifyScheduleModifier(mlir::omp::ScheduleModifier::none);
270     }
271   } else if (modifiers.size() == 2) {
272     // If there are two modifier:
273     // First modifier should not be simd, second one should be simd
274     if (symbolizeScheduleModifier(modifiers[0]) ==
275             mlir::omp::ScheduleModifier::simd ||
276         symbolizeScheduleModifier(modifiers[1]) !=
277             mlir::omp::ScheduleModifier::simd)
278       return parser.emitError(parser.getNameLoc())
279              << " incorrect modifier order";
280   }
281   return success();
282 }
283 
284 /// schedule ::= `schedule` `(` sched-list `)`
285 /// sched-list ::= sched-val | sched-val sched-list |
286 ///                sched-val `,` sched-modifier
287 /// sched-val ::= sched-with-chunk | sched-wo-chunk
288 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
289 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
290 /// sched-wo-chunk ::=  `auto` | `runtime`
291 /// sched-modifier ::=  sched-mod-val | sched-mod-val `,` sched-mod-val
292 /// sched-mod-val ::=  `monotonic` | `nonmonotonic` | `simd` | `none`
293 static ParseResult
294 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
295                     SmallVectorImpl<SmallString<12>> &modifiers,
296                     Optional<OpAsmParser::OperandType> &chunkSize) {
297   if (parser.parseLParen())
298     return failure();
299 
300   StringRef keyword;
301   if (parser.parseKeyword(&keyword))
302     return failure();
303 
304   schedule = keyword;
305   if (keyword == "static" || keyword == "dynamic" || keyword == "guided") {
306     if (succeeded(parser.parseOptionalEqual())) {
307       chunkSize = OpAsmParser::OperandType{};
308       if (parser.parseOperand(*chunkSize))
309         return failure();
310     } else {
311       chunkSize = llvm::NoneType::None;
312     }
313   } else if (keyword == "auto" || keyword == "runtime") {
314     chunkSize = llvm::NoneType::None;
315   } else {
316     return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
317   }
318 
319   // If there is a comma, we have one or more modifiers..
320   while (succeeded(parser.parseOptionalComma())) {
321     StringRef mod;
322     if (parser.parseKeyword(&mod))
323       return failure();
324     modifiers.push_back(mod);
325   }
326 
327   if (parser.parseRParen())
328     return failure();
329 
330   if (verifyScheduleModifiers(parser, modifiers))
331     return failure();
332 
333   return success();
334 }
335 
336 /// Print schedule clause
337 static void printScheduleClause(OpAsmPrinter &p, StringRef &sched,
338                                 llvm::Optional<StringRef> modifier, bool simd,
339                                 Value scheduleChunkVar) {
340   std::string schedLower = sched.lower();
341   p << "schedule(" << schedLower;
342   if (scheduleChunkVar)
343     p << " = " << scheduleChunkVar;
344   if (modifier && modifier.hasValue())
345     p << ", " << modifier;
346   if (simd)
347     p << ", simd";
348   p << ") ";
349 }
350 
351 //===----------------------------------------------------------------------===//
352 // Parser, printer and verifier for ReductionVarList
353 //===----------------------------------------------------------------------===//
354 
355 /// reduction ::= `reduction` `(` reduction-entry-list `)`
356 /// reduction-entry-list ::= reduction-entry
357 ///                        | reduction-entry-list `,` reduction-entry
358 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
359 static ParseResult
360 parseReductionVarList(OpAsmParser &parser,
361                       SmallVectorImpl<SymbolRefAttr> &symbols,
362                       SmallVectorImpl<OpAsmParser::OperandType> &operands,
363                       SmallVectorImpl<Type> &types) {
364   if (failed(parser.parseLParen()))
365     return failure();
366 
367   do {
368     if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() ||
369         parser.parseOperand(operands.emplace_back()) ||
370         parser.parseColonType(types.emplace_back()))
371       return failure();
372   } while (succeeded(parser.parseOptionalComma()));
373   return parser.parseRParen();
374 }
375 
376 /// Print Reduction clause
377 static void printReductionVarList(OpAsmPrinter &p,
378                                   Optional<ArrayAttr> reductions,
379                                   OperandRange reductionVars) {
380   p << "reduction(";
381   for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
382     if (i != 0)
383       p << ", ";
384     p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
385       << reductionVars[i].getType();
386   }
387   p << ") ";
388 }
389 
390 /// Verifies Reduction Clause
391 static LogicalResult verifyReductionVarList(Operation *op,
392                                             Optional<ArrayAttr> reductions,
393                                             OperandRange reductionVars) {
394   if (!reductionVars.empty()) {
395     if (!reductions || reductions->size() != reductionVars.size())
396       return op->emitOpError()
397              << "expected as many reduction symbol references "
398                 "as reduction variables";
399   } else {
400     if (reductions)
401       return op->emitOpError() << "unexpected reduction symbol references";
402     return success();
403   }
404 
405   DenseSet<Value> accumulators;
406   for (auto args : llvm::zip(reductionVars, *reductions)) {
407     Value accum = std::get<0>(args);
408 
409     if (!accumulators.insert(accum).second)
410       return op->emitOpError() << "accumulator variable used more than once";
411 
412     Type varType = accum.getType().cast<PointerLikeType>();
413     auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
414     auto decl =
415         SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
416     if (!decl)
417       return op->emitOpError() << "expected symbol reference " << symbolRef
418                                << " to point to a reduction declaration";
419 
420     if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
421       return op->emitOpError()
422              << "expected accumulator (" << varType
423              << ") to be the same type as reduction declaration ("
424              << decl.getAccumulatorType() << ")";
425   }
426 
427   return success();
428 }
429 
430 //===----------------------------------------------------------------------===//
431 // Parser, printer and verifier for Synchronization Hint (2.17.12)
432 //===----------------------------------------------------------------------===//
433 
434 /// Parses a Synchronization Hint clause. The value of hint is an integer
435 /// which is a combination of different hints from `omp_sync_hint_t`.
436 ///
437 /// hint-clause = `hint` `(` hint-value `)`
438 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
439                                             IntegerAttr &hintAttr,
440                                             bool parseKeyword = true) {
441   if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) {
442     hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
443     return success();
444   }
445 
446   if (failed(parser.parseLParen()))
447     return failure();
448   StringRef hintKeyword;
449   int64_t hint = 0;
450   do {
451     if (failed(parser.parseKeyword(&hintKeyword)))
452       return failure();
453     if (hintKeyword == "uncontended")
454       hint |= 1;
455     else if (hintKeyword == "contended")
456       hint |= 2;
457     else if (hintKeyword == "nonspeculative")
458       hint |= 4;
459     else if (hintKeyword == "speculative")
460       hint |= 8;
461     else
462       return parser.emitError(parser.getCurrentLocation())
463              << hintKeyword << " is not a valid hint";
464   } while (succeeded(parser.parseOptionalComma()));
465   if (failed(parser.parseRParen()))
466     return failure();
467   hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
468   return success();
469 }
470 
471 /// Prints a Synchronization Hint clause
472 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
473                                      IntegerAttr hintAttr) {
474   int64_t hint = hintAttr.getInt();
475 
476   if (hint == 0)
477     return;
478 
479   // Helper function to get n-th bit from the right end of `value`
480   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
481 
482   bool uncontended = bitn(hint, 0);
483   bool contended = bitn(hint, 1);
484   bool nonspeculative = bitn(hint, 2);
485   bool speculative = bitn(hint, 3);
486 
487   SmallVector<StringRef> hints;
488   if (uncontended)
489     hints.push_back("uncontended");
490   if (contended)
491     hints.push_back("contended");
492   if (nonspeculative)
493     hints.push_back("nonspeculative");
494   if (speculative)
495     hints.push_back("speculative");
496 
497   p << "hint(";
498   llvm::interleaveComma(hints, p);
499   p << ") ";
500 }
501 
502 /// Verifies a synchronization hint clause
503 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
504 
505   // Helper function to get n-th bit from the right end of `value`
506   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
507 
508   bool uncontended = bitn(hint, 0);
509   bool contended = bitn(hint, 1);
510   bool nonspeculative = bitn(hint, 2);
511   bool speculative = bitn(hint, 3);
512 
513   if (uncontended && contended)
514     return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
515                                 "omp_sync_hint_contended cannot be combined";
516   if (nonspeculative && speculative)
517     return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
518                                 "omp_sync_hint_speculative cannot be combined.";
519   return success();
520 }
521 
/// Identifies each clause kind that parseClauses can handle. The enumerator
/// values are used as indices into per-clause bookkeeping arrays sized by
/// `COUNT`, so `COUNT` must remain the last enumerator.
enum ClauseType {
  ifClause,
  numThreadsClause,
  privateClause,
  firstprivateClause,
  lastprivateClause,
  sharedClause,
  copyinClause,
  allocateClause,
  defaultClause,
  procBindClause,
  reductionClause,
  nowaitClause,
  linearClause,
  scheduleClause,
  collapseClause,
  orderClause,
  orderedClause,
  memoryOrderClause,
  hintClause,
  // Number of clause kinds; keep this last.
  COUNT
};
544 
545 //===----------------------------------------------------------------------===//
546 // Parser for Clause List
547 //===----------------------------------------------------------------------===//
548 
/// Parse a list of clauses. The clauses can appear in any order, but their
/// operand segment indices are in the same order that they are passed in the
/// `clauses` list. The resulting operand segment sizes are added to
/// `segments`.
///
553 /// clause-list ::= clause clause-list | empty
/// clause ::= if | num-threads | private | firstprivate | lastprivate |
///            shared | copyin | allocate | default | proc-bind | reduction |
///            nowait | linear | schedule | collapse | order | ordered |
///            memory-order | hint
558 /// if ::= `if` `(` ssa-id-and-type `)`
559 /// num-threads ::= `num_threads` `(` ssa-id-and-type `)`
560 /// private ::= `private` operand-and-type-list
561 /// firstprivate ::= `firstprivate` operand-and-type-list
562 /// lastprivate ::= `lastprivate` operand-and-type-list
563 /// shared ::= `shared` operand-and-type-list
564 /// copyin ::= `copyin` operand-and-type-list
565 /// allocate ::= `allocate` `(` allocate-operand-list `)`
/// default ::= `default` `(` (`private` | `firstprivate` | `shared` |
///             `none`) `)`
567 /// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
568 /// reduction ::= `reduction` `(` reduction-entry-list `)`
569 /// nowait ::= `nowait`
570 /// linear ::= `linear` `(` linear-list `)`
571 /// schedule ::= `schedule` `(` sched-list `)`
/// collapse ::= `collapse` `(` integer-literal `)`
/// order ::= `order` `(` `concurrent` `)`
/// ordered ::= `ordered` (`(` integer-literal `)`)?
/// memory-order ::= `memory_order` `(` bare-id `)`
/// hint ::= `hint` `(` hint-value-list `)`
576 ///
/// Note that each clause can only appear once in the clause-list.
578 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
579                                 SmallVectorImpl<ClauseType> &clauses,
580                                 SmallVectorImpl<int> &segments) {
581 
582   // Check done[clause] to see if it has been parsed already
583   llvm::BitVector done(ClauseType::COUNT, false);
584 
585   // See pos[clause] to get position of clause in operand segments
586   SmallVector<int> pos(ClauseType::COUNT, -1);
587 
588   // Stores the last parsed clause keyword
589   StringRef clauseKeyword;
590   StringRef opName = result.name.getStringRef();
591 
592   // Containers for storing operands, types and attributes for various clauses
593   std::pair<OpAsmParser::OperandType, Type> ifCond;
594   std::pair<OpAsmParser::OperandType, Type> numThreads;
595 
596   SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates,
597       shareds, copyins;
598   SmallVector<Type> privateTypes, firstprivateTypes, lastprivateTypes,
599       sharedTypes, copyinTypes;
600 
601   SmallVector<OpAsmParser::OperandType> allocates, allocators;
602   SmallVector<Type> allocateTypes, allocatorTypes;
603 
604   SmallVector<SymbolRefAttr> reductionSymbols;
605   SmallVector<OpAsmParser::OperandType> reductionVars;
606   SmallVector<Type> reductionVarTypes;
607 
608   SmallVector<OpAsmParser::OperandType> linears;
609   SmallVector<Type> linearTypes;
610   SmallVector<OpAsmParser::OperandType> linearSteps;
611 
612   SmallString<8> schedule;
613   SmallVector<SmallString<12>> modifiers;
614   Optional<OpAsmParser::OperandType> scheduleChunkSize;
615 
616   // Compute the position of clauses in operand segments
617   int currPos = 0;
618   for (ClauseType clause : clauses) {
619 
620     // Skip the following clauses - they do not take any position in operand
621     // segments
622     if (clause == defaultClause || clause == procBindClause ||
623         clause == nowaitClause || clause == collapseClause ||
624         clause == orderClause || clause == orderedClause)
625       continue;
626 
627     pos[clause] = currPos++;
628 
629     // For the following clauses, two positions are reserved in the operand
630     // segments
631     if (clause == allocateClause || clause == linearClause)
632       currPos++;
633   }
634 
635   SmallVector<int> clauseSegments(currPos);
636 
637   // Helper function to check if a clause is allowed/repeated or not
638   auto checkAllowed = [&](ClauseType clause,
639                           bool allowRepeat = false) -> ParseResult {
640     if (!llvm::is_contained(clauses, clause))
641       return parser.emitError(parser.getCurrentLocation())
642              << clauseKeyword << " is not a valid clause for the " << opName
643              << " operation";
644     if (done[clause] && !allowRepeat)
645       return parser.emitError(parser.getCurrentLocation())
646              << "at most one " << clauseKeyword << " clause can appear on the "
647              << opName << " operation";
648     done[clause] = true;
649     return success();
650   };
651 
652   while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) {
653     if (clauseKeyword == "if") {
654       if (checkAllowed(ifClause) || parser.parseLParen() ||
655           parser.parseOperand(ifCond.first) ||
656           parser.parseColonType(ifCond.second) || parser.parseRParen())
657         return failure();
658       clauseSegments[pos[ifClause]] = 1;
659     } else if (clauseKeyword == "num_threads") {
660       if (checkAllowed(numThreadsClause) || parser.parseLParen() ||
661           parser.parseOperand(numThreads.first) ||
662           parser.parseColonType(numThreads.second) || parser.parseRParen())
663         return failure();
664       clauseSegments[pos[numThreadsClause]] = 1;
665     } else if (clauseKeyword == "private") {
666       if (checkAllowed(privateClause) ||
667           parseOperandAndTypeList(parser, privates, privateTypes))
668         return failure();
669       clauseSegments[pos[privateClause]] = privates.size();
670     } else if (clauseKeyword == "firstprivate") {
671       if (checkAllowed(firstprivateClause) ||
672           parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
673         return failure();
674       clauseSegments[pos[firstprivateClause]] = firstprivates.size();
675     } else if (clauseKeyword == "lastprivate") {
676       if (checkAllowed(lastprivateClause) ||
677           parseOperandAndTypeList(parser, lastprivates, lastprivateTypes))
678         return failure();
679       clauseSegments[pos[lastprivateClause]] = lastprivates.size();
680     } else if (clauseKeyword == "shared") {
681       if (checkAllowed(sharedClause) ||
682           parseOperandAndTypeList(parser, shareds, sharedTypes))
683         return failure();
684       clauseSegments[pos[sharedClause]] = shareds.size();
685     } else if (clauseKeyword == "copyin") {
686       if (checkAllowed(copyinClause) ||
687           parseOperandAndTypeList(parser, copyins, copyinTypes))
688         return failure();
689       clauseSegments[pos[copyinClause]] = copyins.size();
690     } else if (clauseKeyword == "allocate") {
691       if (checkAllowed(allocateClause) ||
692           parseAllocateAndAllocator(parser, allocates, allocateTypes,
693                                     allocators, allocatorTypes))
694         return failure();
695       clauseSegments[pos[allocateClause]] = allocates.size();
696       clauseSegments[pos[allocateClause] + 1] = allocators.size();
697     } else if (clauseKeyword == "default") {
698       StringRef defval;
699       if (checkAllowed(defaultClause) || parser.parseLParen() ||
700           parser.parseKeyword(&defval) || parser.parseRParen())
701         return failure();
702       // The def prefix is required for the attribute as "private" is a keyword
703       // in C++.
704       auto attr = parser.getBuilder().getStringAttr("def" + defval);
705       result.addAttribute("default_val", attr);
706     } else if (clauseKeyword == "proc_bind") {
707       StringRef bind;
708       if (checkAllowed(procBindClause) || parser.parseLParen() ||
709           parser.parseKeyword(&bind) || parser.parseRParen())
710         return failure();
711       auto attr = parser.getBuilder().getStringAttr(bind);
712       result.addAttribute("proc_bind_val", attr);
713     } else if (clauseKeyword == "reduction") {
714       if (checkAllowed(reductionClause) ||
715           parseReductionVarList(parser, reductionSymbols, reductionVars,
716                                 reductionVarTypes))
717         return failure();
718       clauseSegments[pos[reductionClause]] = reductionVars.size();
719     } else if (clauseKeyword == "nowait") {
720       if (checkAllowed(nowaitClause))
721         return failure();
722       auto attr = UnitAttr::get(parser.getBuilder().getContext());
723       result.addAttribute("nowait", attr);
724     } else if (clauseKeyword == "linear") {
725       if (checkAllowed(linearClause) ||
726           parseLinearClause(parser, linears, linearTypes, linearSteps))
727         return failure();
728       clauseSegments[pos[linearClause]] = linears.size();
729       clauseSegments[pos[linearClause] + 1] = linearSteps.size();
730     } else if (clauseKeyword == "schedule") {
731       if (checkAllowed(scheduleClause) ||
732           parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize))
733         return failure();
734       if (scheduleChunkSize) {
735         clauseSegments[pos[scheduleClause]] = 1;
736       }
737     } else if (clauseKeyword == "collapse") {
738       auto type = parser.getBuilder().getI64Type();
739       mlir::IntegerAttr attr;
740       if (checkAllowed(collapseClause) || parser.parseLParen() ||
741           parser.parseAttribute(attr, type) || parser.parseRParen())
742         return failure();
743       result.addAttribute("collapse_val", attr);
744     } else if (clauseKeyword == "ordered") {
745       mlir::IntegerAttr attr;
746       if (checkAllowed(orderedClause))
747         return failure();
748       if (succeeded(parser.parseOptionalLParen())) {
749         auto type = parser.getBuilder().getI64Type();
750         if (parser.parseAttribute(attr, type) || parser.parseRParen())
751           return failure();
752       } else {
753         // Use 0 to represent no ordered parameter was specified
754         attr = parser.getBuilder().getI64IntegerAttr(0);
755       }
756       result.addAttribute("ordered_val", attr);
757     } else if (clauseKeyword == "order") {
758       StringRef order;
759       if (checkAllowed(orderClause) || parser.parseLParen() ||
760           parser.parseKeyword(&order) || parser.parseRParen())
761         return failure();
762       auto attr = parser.getBuilder().getStringAttr(order);
763       result.addAttribute("order_val", attr);
764     } else if (clauseKeyword == "memory_order") {
765       StringRef memoryOrder;
766       if (checkAllowed(memoryOrderClause) || parser.parseLParen() ||
767           parser.parseKeyword(&memoryOrder) || parser.parseRParen())
768         return failure();
769       result.addAttribute("memory_order",
770                           parser.getBuilder().getStringAttr(memoryOrder));
771     } else if (clauseKeyword == "hint") {
772       IntegerAttr hint;
773       if (checkAllowed(hintClause) ||
774           parseSynchronizationHint(parser, hint, false))
775         return failure();
776       result.addAttribute("hint", hint);
777     } else {
778       return parser.emitError(parser.getNameLoc())
779              << clauseKeyword << " is not a valid clause";
780     }
781   }
782 
783   // Add if parameter.
784   if (done[ifClause] && clauseSegments[pos[ifClause]] &&
785       failed(
786           parser.resolveOperand(ifCond.first, ifCond.second, result.operands)))
787     return failure();
788 
789   // Add num_threads parameter.
790   if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] &&
791       failed(parser.resolveOperand(numThreads.first, numThreads.second,
792                                    result.operands)))
793     return failure();
794 
795   // Add private parameters.
796   if (done[privateClause] && clauseSegments[pos[privateClause]] &&
797       failed(parser.resolveOperands(privates, privateTypes,
798                                     privates[0].location, result.operands)))
799     return failure();
800 
801   // Add firstprivate parameters.
802   if (done[firstprivateClause] && clauseSegments[pos[firstprivateClause]] &&
803       failed(parser.resolveOperands(firstprivates, firstprivateTypes,
804                                     firstprivates[0].location,
805                                     result.operands)))
806     return failure();
807 
808   // Add lastprivate parameters.
809   if (done[lastprivateClause] && clauseSegments[pos[lastprivateClause]] &&
810       failed(parser.resolveOperands(lastprivates, lastprivateTypes,
811                                     lastprivates[0].location, result.operands)))
812     return failure();
813 
814   // Add shared parameters.
815   if (done[sharedClause] && clauseSegments[pos[sharedClause]] &&
816       failed(parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
817                                     result.operands)))
818     return failure();
819 
820   // Add copyin parameters.
821   if (done[copyinClause] && clauseSegments[pos[copyinClause]] &&
822       failed(parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
823                                     result.operands)))
824     return failure();
825 
826   // Add allocate parameters.
827   if (done[allocateClause] && clauseSegments[pos[allocateClause]] &&
828       failed(parser.resolveOperands(allocates, allocateTypes,
829                                     allocates[0].location, result.operands)))
830     return failure();
831 
832   // Add allocator parameters.
833   if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] &&
834       failed(parser.resolveOperands(allocators, allocatorTypes,
835                                     allocators[0].location, result.operands)))
836     return failure();
837 
838   // Add reduction parameters and symbols
839   if (done[reductionClause] && clauseSegments[pos[reductionClause]]) {
840     if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
841                                       parser.getNameLoc(), result.operands)))
842       return failure();
843 
844     SmallVector<Attribute> reductions(reductionSymbols.begin(),
845                                       reductionSymbols.end());
846     result.addAttribute("reductions",
847                         parser.getBuilder().getArrayAttr(reductions));
848   }
849 
850   // Add linear parameters
851   if (done[linearClause] && clauseSegments[pos[linearClause]]) {
852     auto linearStepType = parser.getBuilder().getI32Type();
853     SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
854     if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location,
855                                       result.operands)) ||
856         failed(parser.resolveOperands(linearSteps, linearStepTypes,
857                                       linearSteps[0].location,
858                                       result.operands)))
859       return failure();
860   }
861 
862   // Add schedule parameters
863   if (done[scheduleClause] && !schedule.empty()) {
864     schedule[0] = llvm::toUpper(schedule[0]);
865     auto attr = parser.getBuilder().getStringAttr(schedule);
866     result.addAttribute("schedule_val", attr);
867     if (!modifiers.empty()) {
868       auto mod = parser.getBuilder().getStringAttr(modifiers[0]);
869       result.addAttribute("schedule_modifier", mod);
870       // Only SIMD attribute is allowed here!
871       if (modifiers.size() > 1) {
872         assert(symbolizeScheduleModifier(modifiers[1]) ==
873                mlir::omp::ScheduleModifier::simd);
874         auto attr = UnitAttr::get(parser.getBuilder().getContext());
875         result.addAttribute("simd_modifier", attr);
876       }
877     }
878     if (scheduleChunkSize) {
879       auto chunkSizeType = parser.getBuilder().getI32Type();
880       parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands);
881     }
882   }
883 
884   segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end());
885 
886   return success();
887 }
888 
889 /// Parses a parallel operation.
890 ///
891 /// operation ::= `omp.parallel` clause-list
892 /// clause-list ::= clause | clause clause-list
893 /// clause ::= if | num-threads | private | firstprivate | shared | copyin |
894 ///            allocate | default | proc-bind
895 ///
896 static ParseResult parseParallelOp(OpAsmParser &parser,
897                                    OperationState &result) {
898   SmallVector<ClauseType> clauses = {
899       ifClause,           numThreadsClause, privateClause,
900       firstprivateClause, sharedClause,     copyinClause,
901       allocateClause,     defaultClause,    procBindClause};
902 
903   SmallVector<int> segments;
904 
905   if (failed(parseClauses(parser, result, clauses, segments)))
906     return failure();
907 
908   result.addAttribute("operand_segment_sizes",
909                       parser.getBuilder().getI32VectorAttr(segments));
910 
911   Region *body = result.addRegion();
912   SmallVector<OpAsmParser::OperandType> regionArgs;
913   SmallVector<Type> regionArgTypes;
914   if (parser.parseRegion(*body, regionArgs, regionArgTypes))
915     return failure();
916   return success();
917 }
918 
919 //===----------------------------------------------------------------------===//
920 // Parser, printer and verifier for SectionsOp
921 //===----------------------------------------------------------------------===//
922 
923 /// Parses an OpenMP Sections operation
924 ///
925 /// sections ::= `omp.sections` clause-list
926 /// clause-list ::= clause clause-list | empty
927 /// clause ::= private | firstprivate | lastprivate | reduction | allocate |
928 ///            nowait
929 static ParseResult parseSectionsOp(OpAsmParser &parser,
930                                    OperationState &result) {
931 
932   SmallVector<ClauseType> clauses = {privateClause,     firstprivateClause,
933                                      lastprivateClause, reductionClause,
934                                      allocateClause,    nowaitClause};
935 
936   SmallVector<int> segments;
937 
938   if (failed(parseClauses(parser, result, clauses, segments)))
939     return failure();
940 
941   result.addAttribute("operand_segment_sizes",
942                       parser.getBuilder().getI32VectorAttr(segments));
943 
944   // Now parse the body.
945   Region *body = result.addRegion();
946   if (parser.parseRegion(*body))
947     return failure();
948   return success();
949 }
950 
951 static void printSectionsOp(OpAsmPrinter &p, SectionsOp op) {
952   p << " ";
953   printDataVars(p, op.private_vars(), "private");
954   printDataVars(p, op.firstprivate_vars(), "firstprivate");
955   printDataVars(p, op.lastprivate_vars(), "lastprivate");
956 
957   if (!op.reduction_vars().empty())
958     printReductionVarList(p, op.reductions(), op.reduction_vars());
959 
960   if (!op.allocate_vars().empty())
961     printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars());
962 
963   if (op.nowait())
964     p << "nowait";
965 
966   p << ' ';
967   p.printRegion(op.region());
968 }
969 
970 static LogicalResult verifySectionsOp(SectionsOp op) {
971 
972   // A list item may not appear in more than one clause on the same directive,
973   // except that it may be specified in both firstprivate and lastprivate
974   // clauses.
975   for (auto var : op.private_vars()) {
976     if (llvm::is_contained(op.firstprivate_vars(), var))
977       return op.emitOpError()
978              << "operand used in both private and firstprivate clauses";
979     if (llvm::is_contained(op.lastprivate_vars(), var))
980       return op.emitOpError()
981              << "operand used in both private and lastprivate clauses";
982   }
983 
984   if (op.allocate_vars().size() != op.allocators_vars().size())
985     return op.emitError(
986         "expected equal sizes for allocate and allocator variables");
987 
988   for (auto &inst : *op.region().begin()) {
989     if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst)))
990       op.emitOpError()
991           << "expected omp.section op or terminator op inside region";
992   }
993 
994   return verifyReductionVarList(op, op.reductions(), op.reduction_vars());
995 }
996 
997 /// Parses an OpenMP Workshare Loop operation
998 ///
999 /// wsloop ::= `omp.wsloop` loop-control clause-list
1000 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
1001 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
1002 /// steps := `step` `(`ssa-id-list`)`
1003 /// clause-list ::= clause clause-list | empty
1004 /// clause ::= private | firstprivate | lastprivate | linear | schedule |
1005 //             collapse | nowait | ordered | order | reduction
1006 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
1007 
1008   // Parse an opening `(` followed by induction variables followed by `)`
1009   SmallVector<OpAsmParser::OperandType> ivs;
1010   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
1011                                      OpAsmParser::Delimiter::Paren))
1012     return failure();
1013 
1014   int numIVs = static_cast<int>(ivs.size());
1015   Type loopVarType;
1016   if (parser.parseColonType(loopVarType))
1017     return failure();
1018 
1019   // Parse loop bounds.
1020   SmallVector<OpAsmParser::OperandType> lower;
1021   if (parser.parseEqual() ||
1022       parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
1023       parser.resolveOperands(lower, loopVarType, result.operands))
1024     return failure();
1025 
1026   SmallVector<OpAsmParser::OperandType> upper;
1027   if (parser.parseKeyword("to") ||
1028       parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
1029       parser.resolveOperands(upper, loopVarType, result.operands))
1030     return failure();
1031 
1032   if (succeeded(parser.parseOptionalKeyword("inclusive"))) {
1033     auto attr = UnitAttr::get(parser.getBuilder().getContext());
1034     result.addAttribute("inclusive", attr);
1035   }
1036 
1037   // Parse step values.
1038   SmallVector<OpAsmParser::OperandType> steps;
1039   if (parser.parseKeyword("step") ||
1040       parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
1041       parser.resolveOperands(steps, loopVarType, result.operands))
1042     return failure();
1043 
1044   SmallVector<ClauseType> clauses = {
1045       privateClause,   firstprivateClause, lastprivateClause, linearClause,
1046       reductionClause, collapseClause,     orderClause,       orderedClause,
1047       nowaitClause,    scheduleClause};
1048   SmallVector<int> segments{numIVs, numIVs, numIVs};
1049   if (failed(parseClauses(parser, result, clauses, segments)))
1050     return failure();
1051 
1052   result.addAttribute("operand_segment_sizes",
1053                       parser.getBuilder().getI32VectorAttr(segments));
1054 
1055   // Now parse the body.
1056   Region *body = result.addRegion();
1057   SmallVector<Type> ivTypes(numIVs, loopVarType);
1058   SmallVector<OpAsmParser::OperandType> blockArgs(ivs);
1059   if (parser.parseRegion(*body, blockArgs, ivTypes))
1060     return failure();
1061   return success();
1062 }
1063 
1064 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
1065   auto args = op.getRegion().front().getArguments();
1066   p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound()
1067     << ") to (" << op.upperBound() << ") ";
1068   if (op.inclusive()) {
1069     p << "inclusive ";
1070   }
1071   p << "step (" << op.step() << ") ";
1072 
1073   printDataVars(p, op.private_vars(), "private");
1074   printDataVars(p, op.firstprivate_vars(), "firstprivate");
1075   printDataVars(p, op.lastprivate_vars(), "lastprivate");
1076 
1077   if (!op.linear_vars().empty())
1078     printLinearClause(p, op.linear_vars(), op.linear_step_vars());
1079 
1080   if (auto sched = op.schedule_val())
1081     printScheduleClause(p, sched.getValue(), op.schedule_modifier(),
1082                         op.simd_modifier(), op.schedule_chunk_var());
1083 
1084   if (auto collapse = op.collapse_val())
1085     p << "collapse(" << collapse << ") ";
1086 
1087   if (op.nowait())
1088     p << "nowait ";
1089 
1090   if (auto ordered = op.ordered_val())
1091     p << "ordered(" << ordered << ") ";
1092 
1093   if (auto order = op.order_val())
1094     p << "order(" << order << ") ";
1095 
1096   if (!op.reduction_vars().empty())
1097     printReductionVarList(p, op.reductions(), op.reduction_vars());
1098 
1099   p << ' ';
1100   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
1101 }
1102 
1103 //===----------------------------------------------------------------------===//
1104 // ReductionOp
1105 //===----------------------------------------------------------------------===//
1106 
1107 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
1108                                               Region &region) {
1109   if (parser.parseOptionalKeyword("atomic"))
1110     return success();
1111   return parser.parseRegion(region);
1112 }
1113 
1114 static void printAtomicReductionRegion(OpAsmPrinter &printer,
1115                                        ReductionDeclareOp op, Region &region) {
1116   if (region.empty())
1117     return;
1118   printer << "atomic ";
1119   printer.printRegion(region);
1120 }
1121 
1122 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) {
1123   if (op.initializerRegion().empty())
1124     return op.emitOpError() << "expects non-empty initializer region";
1125   Block &initializerEntryBlock = op.initializerRegion().front();
1126   if (initializerEntryBlock.getNumArguments() != 1 ||
1127       initializerEntryBlock.getArgument(0).getType() != op.type()) {
1128     return op.emitOpError() << "expects initializer region with one argument "
1129                                "of the reduction type";
1130   }
1131 
1132   for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) {
1133     if (yieldOp.results().size() != 1 ||
1134         yieldOp.results().getTypes()[0] != op.type())
1135       return op.emitOpError() << "expects initializer region to yield a value "
1136                                  "of the reduction type";
1137   }
1138 
1139   if (op.reductionRegion().empty())
1140     return op.emitOpError() << "expects non-empty reduction region";
1141   Block &reductionEntryBlock = op.reductionRegion().front();
1142   if (reductionEntryBlock.getNumArguments() != 2 ||
1143       reductionEntryBlock.getArgumentTypes()[0] !=
1144           reductionEntryBlock.getArgumentTypes()[1] ||
1145       reductionEntryBlock.getArgumentTypes()[0] != op.type())
1146     return op.emitOpError() << "expects reduction region with two arguments of "
1147                                "the reduction type";
1148   for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) {
1149     if (yieldOp.results().size() != 1 ||
1150         yieldOp.results().getTypes()[0] != op.type())
1151       return op.emitOpError() << "expects reduction region to yield a value "
1152                                  "of the reduction type";
1153   }
1154 
1155   if (op.atomicReductionRegion().empty())
1156     return success();
1157 
1158   Block &atomicReductionEntryBlock = op.atomicReductionRegion().front();
1159   if (atomicReductionEntryBlock.getNumArguments() != 2 ||
1160       atomicReductionEntryBlock.getArgumentTypes()[0] !=
1161           atomicReductionEntryBlock.getArgumentTypes()[1])
1162     return op.emitOpError() << "expects atomic reduction region with two "
1163                                "arguments of the same type";
1164   auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
1165                      .dyn_cast<PointerLikeType>();
1166   if (!ptrType || ptrType.getElementType() != op.type())
1167     return op.emitOpError() << "expects atomic reduction region arguments to "
1168                                "be accumulators containing the reduction type";
1169   return success();
1170 }
1171 
1172 static LogicalResult verifyReductionOp(ReductionOp op) {
1173   // TODO: generalize this to an op interface when there is more than one op
1174   // that supports reductions.
1175   auto container = op->getParentOfType<WsLoopOp>();
1176   for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
1177     if (container.reduction_vars()[i] == op.accumulator())
1178       return success();
1179 
1180   return op.emitOpError() << "the accumulator is not used by the parent";
1181 }
1182 
1183 //===----------------------------------------------------------------------===//
1184 // WsLoopOp
1185 //===----------------------------------------------------------------------===//
1186 
1187 void WsLoopOp::build(OpBuilder &builder, OperationState &state,
1188                      ValueRange lowerBound, ValueRange upperBound,
1189                      ValueRange step, ArrayRef<NamedAttribute> attributes) {
1190   build(builder, state, TypeRange(), lowerBound, upperBound, step,
1191         /*privateVars=*/ValueRange(),
1192         /*firstprivateVars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
1193         /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
1194         /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr,
1195         /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr,
1196         /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr,
1197         /*inclusive=*/nullptr, /*buildBody=*/false);
1198   state.addAttributes(attributes);
1199 }
1200 
1201 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes,
1202                      ValueRange operands, ArrayRef<NamedAttribute> attributes) {
1203   state.addOperands(operands);
1204   state.addAttributes(attributes);
1205   (void)state.addRegion();
1206   assert(resultTypes.empty() && "mismatched number of return types");
1207   state.addTypes(resultTypes);
1208 }
1209 
1210 void WsLoopOp::build(OpBuilder &builder, OperationState &result,
1211                      TypeRange typeRange, ValueRange lowerBounds,
1212                      ValueRange upperBounds, ValueRange steps,
1213                      ValueRange privateVars, ValueRange firstprivateVars,
1214                      ValueRange lastprivateVars, ValueRange linearVars,
1215                      ValueRange linearStepVars, ValueRange reductionVars,
1216                      StringAttr scheduleVal, Value scheduleChunkVar,
1217                      IntegerAttr collapseVal, UnitAttr nowait,
1218                      IntegerAttr orderedVal, StringAttr orderVal,
1219                      UnitAttr inclusive, bool buildBody) {
1220   result.addOperands(lowerBounds);
1221   result.addOperands(upperBounds);
1222   result.addOperands(steps);
1223   result.addOperands(privateVars);
1224   result.addOperands(firstprivateVars);
1225   result.addOperands(linearVars);
1226   result.addOperands(linearStepVars);
1227   if (scheduleChunkVar)
1228     result.addOperands(scheduleChunkVar);
1229 
1230   if (scheduleVal)
1231     result.addAttribute("schedule_val", scheduleVal);
1232   if (collapseVal)
1233     result.addAttribute("collapse_val", collapseVal);
1234   if (nowait)
1235     result.addAttribute("nowait", nowait);
1236   if (orderedVal)
1237     result.addAttribute("ordered_val", orderedVal);
1238   if (orderVal)
1239     result.addAttribute("order", orderVal);
1240   if (inclusive)
1241     result.addAttribute("inclusive", inclusive);
1242   result.addAttribute(
1243       WsLoopOp::getOperandSegmentSizeAttr(),
1244       builder.getI32VectorAttr(
1245           {static_cast<int32_t>(lowerBounds.size()),
1246            static_cast<int32_t>(upperBounds.size()),
1247            static_cast<int32_t>(steps.size()),
1248            static_cast<int32_t>(privateVars.size()),
1249            static_cast<int32_t>(firstprivateVars.size()),
1250            static_cast<int32_t>(lastprivateVars.size()),
1251            static_cast<int32_t>(linearVars.size()),
1252            static_cast<int32_t>(linearStepVars.size()),
1253            static_cast<int32_t>(reductionVars.size()),
1254            static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)}));
1255 
1256   Region *bodyRegion = result.addRegion();
1257   if (buildBody) {
1258     OpBuilder::InsertionGuard guard(builder);
1259     unsigned numIVs = steps.size();
1260     SmallVector<Type, 8> argTypes(numIVs, steps.getType().front());
1261     builder.createBlock(bodyRegion, {}, argTypes);
1262   }
1263 }
1264 
1265 static LogicalResult verifyWsLoopOp(WsLoopOp op) {
1266   return verifyReductionVarList(op, op.reductions(), op.reduction_vars());
1267 }
1268 
1269 //===----------------------------------------------------------------------===//
1270 // Verifier for critical construct (2.17.1)
1271 //===----------------------------------------------------------------------===//
1272 
1273 static LogicalResult verifyCriticalDeclareOp(CriticalDeclareOp op) {
1274   return verifySynchronizationHint(op, op.hint());
1275 }
1276 
1277 static LogicalResult verifyCriticalOp(CriticalOp op) {
1278 
1279   if (op.nameAttr()) {
1280     auto symbolRef = op.nameAttr().cast<SymbolRefAttr>();
1281     auto decl =
1282         SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef);
1283     if (!decl) {
1284       return op.emitOpError() << "expected symbol reference " << symbolRef
1285                               << " to point to a critical declaration";
1286     }
1287   }
1288 
1289   return success();
1290 }
1291 
1292 //===----------------------------------------------------------------------===//
1293 // Verifier for ordered construct
1294 //===----------------------------------------------------------------------===//
1295 
1296 static LogicalResult verifyOrderedOp(OrderedOp op) {
1297   auto container = op->getParentOfType<WsLoopOp>();
1298   if (!container || !container.ordered_valAttr() ||
1299       container.ordered_valAttr().getInt() == 0)
1300     return op.emitOpError() << "ordered depend directive must be closely "
1301                             << "nested inside a worksharing-loop with ordered "
1302                             << "clause with parameter present";
1303 
1304   if (container.ordered_valAttr().getInt() !=
1305       (int64_t)op.num_loops_val().getValue())
1306     return op.emitOpError() << "number of variables in depend clause does not "
1307                             << "match number of iteration variables in the "
1308                             << "doacross loop";
1309 
1310   return success();
1311 }
1312 
1313 static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) {
1314   // TODO: The code generation for ordered simd directive is not supported yet.
1315   if (op.simd())
1316     return failure();
1317 
1318   if (auto container = op->getParentOfType<WsLoopOp>()) {
1319     if (!container.ordered_valAttr() ||
1320         container.ordered_valAttr().getInt() != 0)
1321       return op.emitOpError() << "ordered region must be closely nested inside "
1322                               << "a worksharing-loop region with an ordered "
1323                               << "clause without parameter present";
1324   }
1325 
1326   return success();
1327 }
1328 
1329 //===----------------------------------------------------------------------===//
1330 // AtomicReadOp
1331 //===----------------------------------------------------------------------===//
1332 
1333 /// Parser for AtomicReadOp
1334 ///
1335 /// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type
1336 /// address ::= operand `:` type
1337 static ParseResult parseAtomicReadOp(OpAsmParser &parser,
1338                                      OperationState &result) {
1339   OpAsmParser::OperandType x, v;
1340   Type addressType;
1341   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1342   SmallVector<int> segments;
1343 
1344   if (parser.parseOperand(v) || parser.parseEqual() || parser.parseOperand(x) ||
1345       parseClauses(parser, result, clauses, segments) ||
1346       parser.parseColonType(addressType) ||
1347       parser.resolveOperand(x, addressType, result.operands) ||
1348       parser.resolveOperand(v, addressType, result.operands))
1349     return failure();
1350 
1351   return success();
1352 }
1353 
1354 /// Printer for AtomicReadOp
1355 static void printAtomicReadOp(OpAsmPrinter &p, AtomicReadOp op) {
1356   p << " " << op.v() << " = " << op.x() << " ";
1357   if (op.memory_order())
1358     p << "memory_order(" << op.memory_order().getValue() << ") ";
1359   if (op.hintAttr())
1360     printSynchronizationHint(p << " ", op, op.hintAttr());
1361   p << ": " << op.x().getType();
1362 }
1363 
1364 /// Verifier for AtomicReadOp
1365 static LogicalResult verifyAtomicReadOp(AtomicReadOp op) {
1366   if (op.memory_order()) {
1367     StringRef memOrder = op.memory_order().getValue();
1368     if (memOrder.equals("acq_rel") || memOrder.equals("release"))
1369       return op.emitError(
1370           "memory-order must not be acq_rel or release for atomic reads");
1371   }
1372   if (op.x() == op.v())
1373     return op.emitError(
1374         "read and write must not be to the same location for atomic reads");
1375   return verifySynchronizationHint(op, op.hint());
1376 }
1377 
1378 //===----------------------------------------------------------------------===//
1379 // AtomicWriteOp
1380 //===----------------------------------------------------------------------===//
1381 
1382 /// Parser for AtomicWriteOp
1383 ///
1384 /// operation ::= `omp.atomic.write` atomic-clause-list operands
1385 /// operands ::= address `,` value
1386 /// address ::= operand `:` type
1387 /// value ::= operand `:` type
1388 static ParseResult parseAtomicWriteOp(OpAsmParser &parser,
1389                                       OperationState &result) {
1390   OpAsmParser::OperandType address, value;
1391   Type addrType, valueType;
1392   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1393   SmallVector<int> segments;
1394 
1395   if (parser.parseOperand(address) || parser.parseEqual() ||
1396       parser.parseOperand(value) ||
1397       parseClauses(parser, result, clauses, segments) ||
1398       parser.parseColonType(addrType) || parser.parseComma() ||
1399       parser.parseType(valueType) ||
1400       parser.resolveOperand(address, addrType, result.operands) ||
1401       parser.resolveOperand(value, valueType, result.operands))
1402     return failure();
1403   return success();
1404 }
1405 
1406 /// Printer for AtomicWriteOp
1407 static void printAtomicWriteOp(OpAsmPrinter &p, AtomicWriteOp op) {
1408   p << " " << op.address() << " = " << op.value() << " ";
1409   if (op.memory_order())
1410     p << "memory_order(" << op.memory_order() << ") ";
1411   if (op.hintAttr())
1412     printSynchronizationHint(p, op, op.hintAttr());
1413   p << ": " << op.address().getType() << ", " << op.value().getType();
1414 }
1415 
1416 /// Verifier for AtomicWriteOp
1417 static LogicalResult verifyAtomicWriteOp(AtomicWriteOp op) {
1418   if (op.memory_order()) {
1419     StringRef memoryOrder = op.memory_order().getValue();
1420     if (memoryOrder.equals("acq_rel") || memoryOrder.equals("acquire"))
1421       return op.emitError(
1422           "memory-order must not be acq_rel or acquire for atomic writes");
1423   }
1424   return verifySynchronizationHint(op, op.hint());
1425 }
1426 
1427 //===----------------------------------------------------------------------===//
1428 // AtomicUpdateOp
1429 //===----------------------------------------------------------------------===//
1430 
1431 /// Parser for AtomicUpdateOp
1432 ///
1433 /// operation ::= `omp.atomic.update` atomic-clause-list region
1434 static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
1435                                        OperationState &result) {
1436   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1437   SmallVector<int> segments;
1438   OpAsmParser::OperandType x, y, z;
1439   Type xType, exprType;
1440   StringRef binOp;
1441 
1442   // x = y `op` z : xtype, exprtype
1443   if (parser.parseOperand(x) || parser.parseEqual() || parser.parseOperand(y) ||
1444       parser.parseKeyword(&binOp) || parser.parseOperand(z) ||
1445       parseClauses(parser, result, clauses, segments) || parser.parseColon() ||
1446       parser.parseType(xType) || parser.parseComma() ||
1447       parser.parseType(exprType) ||
1448       parser.resolveOperand(x, xType, result.operands)) {
1449     return failure();
1450   }
1451 
1452   auto binOpEnum = AtomicBinOpKindToEnum(binOp.upper());
1453   if (!binOpEnum)
1454     return parser.emitError(parser.getNameLoc())
1455            << "invalid atomic bin op in atomic update\n";
1456   auto attr =
1457       parser.getBuilder().getI64IntegerAttr((int64_t)binOpEnum.getValue());
1458   result.addAttribute("binop", attr);
1459 
1460   OpAsmParser::OperandType expr;
1461   if (x.name == y.name && x.number == y.number) {
1462     expr = z;
1463     result.addAttribute("isXBinopExpr", parser.getBuilder().getUnitAttr());
1464   } else if (x.name == z.name && x.number == z.number) {
1465     expr = y;
1466   } else {
1467     return parser.emitError(parser.getNameLoc())
1468            << "atomic update variable " << x.name
1469            << " not found in the RHS of the assignment statement in an"
1470               " atomic.update operation";
1471   }
1472   return parser.resolveOperand(expr, exprType, result.operands);
1473 }
1474 
1475 /// Printer for AtomicUpdateOp
1476 static void printAtomicUpdateOp(OpAsmPrinter &p, AtomicUpdateOp op) {
1477   p << " " << op.x() << " = ";
1478   Value y, z;
1479   if (op.isXBinopExpr()) {
1480     y = op.x();
1481     z = op.expr();
1482   } else {
1483     y = op.expr();
1484     z = op.x();
1485   }
1486   p << y << " " << AtomicBinOpKindToString(op.binop()).lower() << " " << z
1487     << " ";
1488   if (op.memory_order())
1489     p << "memory_order(" << op.memory_order() << ") ";
1490   if (op.hintAttr())
1491     printSynchronizationHint(p, op, op.hintAttr());
1492   p << ": " << op.x().getType() << ", " << op.expr().getType();
1493 }
1494 
1495 /// Verifier for AtomicUpdateOp
1496 static LogicalResult verifyAtomicUpdateOp(AtomicUpdateOp op) {
1497   if (op.memory_order()) {
1498     StringRef memoryOrder = op.memory_order().getValue();
1499     if (memoryOrder.equals("acq_rel") || memoryOrder.equals("acquire"))
1500       return op.emitError(
1501           "memory-order must not be acq_rel or acquire for atomic updates");
1502   }
1503   return success();
1504 }
1505 
1506 #define GET_OP_CLASSES
1507 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
1508