1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/OperationSupport.h"
19 
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/SmallString.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/ADT/StringSwitch.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include <cstddef>
27 
28 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
29 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
30 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
31 
32 using namespace mlir;
33 using namespace mlir::omp;
34 
35 namespace {
36 /// Model for pointer-like types that already provide a `getElementType` method.
37 template <typename T>
38 struct PointerLikeModel
39     : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
40   Type getElementType(Type pointer) const {
41     return pointer.cast<T>().getElementType();
42   }
43 };
44 } // namespace
45 
/// Registers the dialect's operations and attributes (from the generated
/// `.cpp.inc` lists) and attaches the OpenMP `PointerLikeType` interface to
/// the LLVM pointer type and the builtin memref type.
void OpenMPDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
      >();

  // These types are defined outside this dialect, so they implement
  // PointerLikeType through external models that forward to each type's own
  // getElementType() method.
  LLVM::LLVMPointerType::attachInterface<
      PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
  MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
}
60 
61 //===----------------------------------------------------------------------===//
62 // ParallelOp
63 //===----------------------------------------------------------------------===//
64 
65 void ParallelOp::build(OpBuilder &builder, OperationState &state,
66                        ArrayRef<NamedAttribute> attributes) {
67   ParallelOp::build(
68       builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
69       /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
70       /*proc_bind_val=*/nullptr);
71   state.addAttributes(attributes);
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // Parser and printer for Allocate Clause
76 //===----------------------------------------------------------------------===//
77 
78 /// Parse an allocate clause with allocators and a list of operands with types.
79 ///
80 /// allocate ::= `allocate` `(` allocate-operand-list `)`
81 /// allocate-operand-list :: = allocate-operand |
82 ///                            allocator-operand `,` allocate-operand-list
83 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
84 /// ssa-id-and-type ::= ssa-id `:` type
85 static ParseResult parseAllocateAndAllocator(
86     OpAsmParser &parser,
87     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
88     SmallVectorImpl<Type> &typesAllocate,
89     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
90     SmallVectorImpl<Type> &typesAllocator) {
91 
92   return parser.parseCommaSeparatedList(
93       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
94         OpAsmParser::OperandType operand;
95         Type type;
96         if (parser.parseOperand(operand) || parser.parseColonType(type))
97           return failure();
98         operandsAllocator.push_back(operand);
99         typesAllocator.push_back(type);
100         if (parser.parseArrow())
101           return failure();
102         if (parser.parseOperand(operand) || parser.parseColonType(type))
103           return failure();
104 
105         operandsAllocate.push_back(operand);
106         typesAllocate.push_back(type);
107         return success();
108       });
109 }
110 
111 /// Print allocate clause
112 static void printAllocateAndAllocator(OpAsmPrinter &p,
113                                       OperandRange varsAllocate,
114                                       OperandRange varsAllocator) {
115   p << "allocate(";
116   for (unsigned i = 0; i < varsAllocate.size(); ++i) {
117     std::string separator = i == varsAllocate.size() - 1 ? ") " : ", ";
118     p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
119     p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
120   }
121 }
122 
123 LogicalResult ParallelOp::verify() {
124   if (allocate_vars().size() != allocators_vars().size())
125     return emitError(
126         "expected equal sizes for allocate and allocator variables");
127   return success();
128 }
129 
130 void ParallelOp::print(OpAsmPrinter &p) {
131   p << " ";
132   if (auto ifCond = if_expr_var())
133     p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
134 
135   if (auto threads = num_threads_var())
136     p << "num_threads(" << threads << " : " << threads.getType() << ") ";
137 
138   if (!allocate_vars().empty())
139     printAllocateAndAllocator(p, allocate_vars(), allocators_vars());
140 
141   if (auto bind = proc_bind_val())
142     p << "proc_bind(" << stringifyClauseProcBindKind(*bind) << ") ";
143 
144   p << ' ';
145   p.printRegion(getRegion());
146 }
147 
148 void TargetOp::print(OpAsmPrinter &p) {
149   p << " ";
150   if (auto ifCond = if_expr())
151     p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
152 
153   if (auto device = this->device())
154     p << "device(" << device << " : " << device.getType() << ") ";
155 
156   if (auto threads = thread_limit())
157     p << "thread_limit(" << threads << " : " << threads.getType() << ") ";
158 
159   if (nowait())
160     p << "nowait ";
161 
162   p.printRegion(getRegion());
163 }
164 
165 //===----------------------------------------------------------------------===//
166 // Parser and printer for Linear Clause
167 //===----------------------------------------------------------------------===//
168 
169 /// linear ::= `linear` `(` linear-list `)`
170 /// linear-list := linear-val | linear-val linear-list
171 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
172 static ParseResult
173 parseLinearClause(OpAsmParser &parser,
174                   SmallVectorImpl<OpAsmParser::OperandType> &vars,
175                   SmallVectorImpl<Type> &types,
176                   SmallVectorImpl<OpAsmParser::OperandType> &stepVars) {
177   if (parser.parseLParen())
178     return failure();
179 
180   do {
181     OpAsmParser::OperandType var;
182     Type type;
183     OpAsmParser::OperandType stepVar;
184     if (parser.parseOperand(var) || parser.parseEqual() ||
185         parser.parseOperand(stepVar) || parser.parseColonType(type))
186       return failure();
187 
188     vars.push_back(var);
189     types.push_back(type);
190     stepVars.push_back(stepVar);
191   } while (succeeded(parser.parseOptionalComma()));
192 
193   if (parser.parseRParen())
194     return failure();
195 
196   return success();
197 }
198 
199 /// Print Linear Clause
200 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars,
201                               OperandRange linearStepVars) {
202   size_t linearVarsSize = linearVars.size();
203   p << "linear(";
204   for (unsigned i = 0; i < linearVarsSize; ++i) {
205     std::string separator = i == linearVarsSize - 1 ? ") " : ", ";
206     p << linearVars[i];
207     if (linearStepVars.size() > i)
208       p << " = " << linearStepVars[i];
209     p << " : " << linearVars[i].getType() << separator;
210   }
211 }
212 
213 //===----------------------------------------------------------------------===//
214 // Parser and printer for Schedule Clause
215 //===----------------------------------------------------------------------===//
216 
217 static ParseResult
218 verifyScheduleModifiers(OpAsmParser &parser,
219                         SmallVectorImpl<SmallString<12>> &modifiers) {
220   if (modifiers.size() > 2)
221     return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
222   for (const auto &mod : modifiers) {
223     // Translate the string. If it has no value, then it was not a valid
224     // modifier!
225     auto symbol = symbolizeScheduleModifier(mod);
226     if (!symbol.hasValue())
227       return parser.emitError(parser.getNameLoc())
228              << " unknown modifier type: " << mod;
229   }
230 
231   // If we have one modifier that is "simd", then stick a "none" modiifer in
232   // index 0.
233   if (modifiers.size() == 1) {
234     if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
235       modifiers.push_back(modifiers[0]);
236       modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
237     }
238   } else if (modifiers.size() == 2) {
239     // If there are two modifier:
240     // First modifier should not be simd, second one should be simd
241     if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
242         symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
243       return parser.emitError(parser.getNameLoc())
244              << " incorrect modifier order";
245   }
246   return success();
247 }
248 
249 /// schedule ::= `schedule` `(` sched-list `)`
250 /// sched-list ::= sched-val | sched-val sched-list |
251 ///                sched-val `,` sched-modifier
252 /// sched-val ::= sched-with-chunk | sched-wo-chunk
253 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
254 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
255 /// sched-wo-chunk ::=  `auto` | `runtime`
256 /// sched-modifier ::=  sched-mod-val | sched-mod-val `,` sched-mod-val
257 /// sched-mod-val ::=  `monotonic` | `nonmonotonic` | `simd` | `none`
258 static ParseResult
259 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
260                     SmallVectorImpl<SmallString<12>> &modifiers,
261                     Optional<OpAsmParser::OperandType> &chunkSize,
262                     Type &chunkType) {
263   if (parser.parseLParen())
264     return failure();
265 
266   StringRef keyword;
267   if (parser.parseKeyword(&keyword))
268     return failure();
269 
270   schedule = keyword;
271   if (keyword == "static" || keyword == "dynamic" || keyword == "guided") {
272     if (succeeded(parser.parseOptionalEqual())) {
273       chunkSize = OpAsmParser::OperandType{};
274       if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
275         return failure();
276     } else {
277       chunkSize = llvm::NoneType::None;
278     }
279   } else if (keyword == "auto" || keyword == "runtime") {
280     chunkSize = llvm::NoneType::None;
281   } else {
282     return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
283   }
284 
285   // If there is a comma, we have one or more modifiers..
286   while (succeeded(parser.parseOptionalComma())) {
287     StringRef mod;
288     if (parser.parseKeyword(&mod))
289       return failure();
290     modifiers.push_back(mod);
291   }
292 
293   if (parser.parseRParen())
294     return failure();
295 
296   if (verifyScheduleModifiers(parser, modifiers))
297     return failure();
298 
299   return success();
300 }
301 
302 /// Print schedule clause
303 static void printScheduleClause(OpAsmPrinter &p, ClauseScheduleKind sched,
304                                 Optional<ScheduleModifier> modifier, bool simd,
305                                 Value scheduleChunkVar) {
306   p << "schedule(" << stringifyClauseScheduleKind(sched).lower();
307   if (scheduleChunkVar)
308     p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType();
309   if (modifier)
310     p << ", " << stringifyScheduleModifier(*modifier);
311   if (simd)
312     p << ", simd";
313   p << ") ";
314 }
315 
316 //===----------------------------------------------------------------------===//
317 // Parser, printer and verifier for ReductionVarList
318 //===----------------------------------------------------------------------===//
319 
320 /// reduction ::= `reduction` `(` reduction-entry-list `)`
321 /// reduction-entry-list ::= reduction-entry
322 ///                        | reduction-entry-list `,` reduction-entry
323 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
324 static ParseResult
325 parseReductionVarList(OpAsmParser &parser,
326                       SmallVectorImpl<SymbolRefAttr> &symbols,
327                       SmallVectorImpl<OpAsmParser::OperandType> &operands,
328                       SmallVectorImpl<Type> &types) {
329   if (failed(parser.parseLParen()))
330     return failure();
331 
332   do {
333     if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() ||
334         parser.parseOperand(operands.emplace_back()) ||
335         parser.parseColonType(types.emplace_back()))
336       return failure();
337   } while (succeeded(parser.parseOptionalComma()));
338   return parser.parseRParen();
339 }
340 
341 /// Print Reduction clause
342 static void printReductionVarList(OpAsmPrinter &p,
343                                   Optional<ArrayAttr> reductions,
344                                   OperandRange reductionVars) {
345   p << "reduction(";
346   for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
347     if (i != 0)
348       p << ", ";
349     p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
350       << reductionVars[i].getType();
351   }
352   p << ") ";
353 }
354 
355 /// Verifies Reduction Clause
356 static LogicalResult verifyReductionVarList(Operation *op,
357                                             Optional<ArrayAttr> reductions,
358                                             OperandRange reductionVars) {
359   if (!reductionVars.empty()) {
360     if (!reductions || reductions->size() != reductionVars.size())
361       return op->emitOpError()
362              << "expected as many reduction symbol references "
363                 "as reduction variables";
364   } else {
365     if (reductions)
366       return op->emitOpError() << "unexpected reduction symbol references";
367     return success();
368   }
369 
370   DenseSet<Value> accumulators;
371   for (auto args : llvm::zip(reductionVars, *reductions)) {
372     Value accum = std::get<0>(args);
373 
374     if (!accumulators.insert(accum).second)
375       return op->emitOpError() << "accumulator variable used more than once";
376 
377     Type varType = accum.getType().cast<PointerLikeType>();
378     auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
379     auto decl =
380         SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
381     if (!decl)
382       return op->emitOpError() << "expected symbol reference " << symbolRef
383                                << " to point to a reduction declaration";
384 
385     if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
386       return op->emitOpError()
387              << "expected accumulator (" << varType
388              << ") to be the same type as reduction declaration ("
389              << decl.getAccumulatorType() << ")";
390   }
391 
392   return success();
393 }
394 
395 //===----------------------------------------------------------------------===//
396 // Parser, printer and verifier for Synchronization Hint (2.17.12)
397 //===----------------------------------------------------------------------===//
398 
399 /// Parses a Synchronization Hint clause. The value of hint is an integer
400 /// which is a combination of different hints from `omp_sync_hint_t`.
401 ///
402 /// hint-clause = `hint` `(` hint-value `)`
403 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
404                                             IntegerAttr &hintAttr,
405                                             bool parseKeyword = true) {
406   if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) {
407     hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
408     return success();
409   }
410 
411   if (failed(parser.parseLParen()))
412     return failure();
413   StringRef hintKeyword;
414   int64_t hint = 0;
415   do {
416     if (failed(parser.parseKeyword(&hintKeyword)))
417       return failure();
418     if (hintKeyword == "uncontended")
419       hint |= 1;
420     else if (hintKeyword == "contended")
421       hint |= 2;
422     else if (hintKeyword == "nonspeculative")
423       hint |= 4;
424     else if (hintKeyword == "speculative")
425       hint |= 8;
426     else
427       return parser.emitError(parser.getCurrentLocation())
428              << hintKeyword << " is not a valid hint";
429   } while (succeeded(parser.parseOptionalComma()));
430   if (failed(parser.parseRParen()))
431     return failure();
432   hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
433   return success();
434 }
435 
436 /// Prints a Synchronization Hint clause
437 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
438                                      IntegerAttr hintAttr) {
439   int64_t hint = hintAttr.getInt();
440 
441   if (hint == 0)
442     return;
443 
444   // Helper function to get n-th bit from the right end of `value`
445   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
446 
447   bool uncontended = bitn(hint, 0);
448   bool contended = bitn(hint, 1);
449   bool nonspeculative = bitn(hint, 2);
450   bool speculative = bitn(hint, 3);
451 
452   SmallVector<StringRef> hints;
453   if (uncontended)
454     hints.push_back("uncontended");
455   if (contended)
456     hints.push_back("contended");
457   if (nonspeculative)
458     hints.push_back("nonspeculative");
459   if (speculative)
460     hints.push_back("speculative");
461 
462   p << "hint(";
463   llvm::interleaveComma(hints, p);
464   p << ") ";
465 }
466 
467 /// Verifies a synchronization hint clause
468 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
469 
470   // Helper function to get n-th bit from the right end of `value`
471   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
472 
473   bool uncontended = bitn(hint, 0);
474   bool contended = bitn(hint, 1);
475   bool nonspeculative = bitn(hint, 2);
476   bool speculative = bitn(hint, 3);
477 
478   if (uncontended && contended)
479     return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
480                                 "omp_sync_hint_contended cannot be combined";
481   if (nonspeculative && speculative)
482     return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
483                                 "omp_sync_hint_speculative cannot be combined.";
484   return success();
485 }
486 
/// Identifies each clause kind handled by parseClauses.
///
/// Enumerator values are used directly as indices into the `done` bit vector
/// and the `pos` table in parseClauses. `COUNT` must therefore stay last so
/// that it equals the number of clause kinds.
enum ClauseType {
  ifClause,
  numThreadsClause,
  deviceClause,
  threadLimitClause,
  allocateClause,
  procBindClause,
  reductionClause,
  nowaitClause,
  linearClause,
  scheduleClause,
  collapseClause,
  orderClause,
  orderedClause,
  memoryOrderClause,
  hintClause,
  // Sentinel: number of clause kinds above (not a real clause).
  COUNT
};
505 
506 //===----------------------------------------------------------------------===//
507 // Parser for Clause List
508 //===----------------------------------------------------------------------===//
509 
510 /// Parse a clause attribute `(` $value `)`.
511 template <typename ClauseAttr>
512 static ParseResult parseClauseAttr(AsmParser &parser, OperationState &state,
513                                    StringRef attrName, StringRef name) {
514   using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
515   StringRef enumStr;
516   SMLoc loc = parser.getCurrentLocation();
517   if (parser.parseLParen() || parser.parseKeyword(&enumStr) ||
518       parser.parseRParen())
519     return failure();
520   if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
521     auto attr = ClauseAttr::get(parser.getContext(), *enumValue);
522     state.addAttribute(attrName, attr);
523     return success();
524   }
525   return parser.emitError(loc, "invalid ") << name << " kind";
526 }
527 
/// Parse a list of clauses. The clauses can appear in any order, but their
/// operand segment indices are in the same order that they are passed in the
/// `clauses` list. The operand segments are added on top of those already
/// present in `segments`.
///
/// clause-list ::= clause clause-list | empty
/// clause ::= if | num-threads | device | thread-limit | allocate | proc-bind
///          | reduction | nowait | linear | schedule | collapse | order
///          | ordered | memory-order | hint
/// if ::= `if` `(` ssa-id-and-type `)`
/// num-threads ::= `num_threads` `(` ssa-id-and-type `)`
/// device ::= `device` `(` ssa-id-and-type `)`
/// thread-limit ::= `thread_limit` `(` ssa-id-and-type `)`
/// allocate ::= `allocate` `(` allocate-operand-list `)`
/// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
/// reduction ::= `reduction` `(` reduction-entry-list `)`
/// nowait ::= `nowait`
/// linear ::= `linear` `(` linear-list `)`
/// schedule ::= `schedule` `(` sched-list `)`
/// collapse ::= `collapse` `(` integer `)`
/// order ::= `order` `(` `concurrent` `)`
/// ordered ::= `ordered` (`(` integer `)`)?
/// memory-order ::= `memory_order` `(` memory-order-kind `)`
/// hint ::= `hint` `(` hint-value (`,` hint-value)* `)`
///
/// Note that each clause can only appear once in the clause-list.
549 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
550                                 SmallVectorImpl<ClauseType> &clauses,
551                                 SmallVectorImpl<int> &segments) {
552 
553   // Check done[clause] to see if it has been parsed already
554   BitVector done(ClauseType::COUNT, false);
555 
556   // See pos[clause] to get position of clause in operand segments
557   SmallVector<int> pos(ClauseType::COUNT, -1);
558 
559   // Stores the last parsed clause keyword
560   StringRef clauseKeyword;
561   StringRef opName = result.name.getStringRef();
562 
563   // Containers for storing operands, types and attributes for various clauses
564   std::pair<OpAsmParser::OperandType, Type> ifCond;
565   std::pair<OpAsmParser::OperandType, Type> numThreads;
566   std::pair<OpAsmParser::OperandType, Type> device;
567   std::pair<OpAsmParser::OperandType, Type> threadLimit;
568 
569   SmallVector<OpAsmParser::OperandType> allocates, allocators;
570   SmallVector<Type> allocateTypes, allocatorTypes;
571 
572   SmallVector<SymbolRefAttr> reductionSymbols;
573   SmallVector<OpAsmParser::OperandType> reductionVars;
574   SmallVector<Type> reductionVarTypes;
575 
576   SmallVector<OpAsmParser::OperandType> linears;
577   SmallVector<Type> linearTypes;
578   SmallVector<OpAsmParser::OperandType> linearSteps;
579 
580   SmallString<8> schedule;
581   SmallVector<SmallString<12>> modifiers;
582   Optional<OpAsmParser::OperandType> scheduleChunkSize;
583   Type scheduleChunkType;
584 
585   // Compute the position of clauses in operand segments
586   int currPos = 0;
587   for (ClauseType clause : clauses) {
588 
589     // Skip the following clauses - they do not take any position in operand
590     // segments
591     if (clause == procBindClause || clause == nowaitClause ||
592         clause == collapseClause || clause == orderClause ||
593         clause == orderedClause)
594       continue;
595 
596     pos[clause] = currPos++;
597 
598     // For the following clauses, two positions are reserved in the operand
599     // segments
600     if (clause == allocateClause || clause == linearClause)
601       currPos++;
602   }
603 
604   SmallVector<int> clauseSegments(currPos);
605 
606   // Helper function to check if a clause is allowed/repeated or not
607   auto checkAllowed = [&](ClauseType clause) -> ParseResult {
608     if (!llvm::is_contained(clauses, clause))
609       return parser.emitError(parser.getCurrentLocation())
610              << clauseKeyword << " is not a valid clause for the " << opName
611              << " operation";
612     if (done[clause])
613       return parser.emitError(parser.getCurrentLocation())
614              << "at most one " << clauseKeyword << " clause can appear on the "
615              << opName << " operation";
616     done[clause] = true;
617     return success();
618   };
619 
620   while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) {
621     if (clauseKeyword == "if") {
622       if (checkAllowed(ifClause) || parser.parseLParen() ||
623           parser.parseOperand(ifCond.first) ||
624           parser.parseColonType(ifCond.second) || parser.parseRParen())
625         return failure();
626       clauseSegments[pos[ifClause]] = 1;
627     } else if (clauseKeyword == "num_threads") {
628       if (checkAllowed(numThreadsClause) || parser.parseLParen() ||
629           parser.parseOperand(numThreads.first) ||
630           parser.parseColonType(numThreads.second) || parser.parseRParen())
631         return failure();
632       clauseSegments[pos[numThreadsClause]] = 1;
633     } else if (clauseKeyword == "device") {
634       if (checkAllowed(deviceClause) || parser.parseLParen() ||
635           parser.parseOperand(device.first) ||
636           parser.parseColonType(device.second) || parser.parseRParen())
637         return failure();
638       clauseSegments[pos[deviceClause]] = 1;
639     } else if (clauseKeyword == "thread_limit") {
640       if (checkAllowed(threadLimitClause) || parser.parseLParen() ||
641           parser.parseOperand(threadLimit.first) ||
642           parser.parseColonType(threadLimit.second) || parser.parseRParen())
643         return failure();
644       clauseSegments[pos[threadLimitClause]] = 1;
645     } else if (clauseKeyword == "allocate") {
646       if (checkAllowed(allocateClause) ||
647           parseAllocateAndAllocator(parser, allocates, allocateTypes,
648                                     allocators, allocatorTypes))
649         return failure();
650       clauseSegments[pos[allocateClause]] = allocates.size();
651       clauseSegments[pos[allocateClause] + 1] = allocators.size();
652     } else if (clauseKeyword == "proc_bind") {
653       if (checkAllowed(procBindClause) ||
654           parseClauseAttr<ClauseProcBindKindAttr>(parser, result,
655                                                   "proc_bind_val", "proc bind"))
656         return failure();
657     } else if (clauseKeyword == "reduction") {
658       if (checkAllowed(reductionClause) ||
659           parseReductionVarList(parser, reductionSymbols, reductionVars,
660                                 reductionVarTypes))
661         return failure();
662       clauseSegments[pos[reductionClause]] = reductionVars.size();
663     } else if (clauseKeyword == "nowait") {
664       if (checkAllowed(nowaitClause))
665         return failure();
666       auto attr = UnitAttr::get(parser.getBuilder().getContext());
667       result.addAttribute("nowait", attr);
668     } else if (clauseKeyword == "linear") {
669       if (checkAllowed(linearClause) ||
670           parseLinearClause(parser, linears, linearTypes, linearSteps))
671         return failure();
672       clauseSegments[pos[linearClause]] = linears.size();
673       clauseSegments[pos[linearClause] + 1] = linearSteps.size();
674     } else if (clauseKeyword == "schedule") {
675       if (checkAllowed(scheduleClause) ||
676           parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize,
677                               scheduleChunkType))
678         return failure();
679       if (scheduleChunkSize) {
680         clauseSegments[pos[scheduleClause]] = 1;
681       }
682     } else if (clauseKeyword == "collapse") {
683       auto type = parser.getBuilder().getI64Type();
684       mlir::IntegerAttr attr;
685       if (checkAllowed(collapseClause) || parser.parseLParen() ||
686           parser.parseAttribute(attr, type) || parser.parseRParen())
687         return failure();
688       result.addAttribute("collapse_val", attr);
689     } else if (clauseKeyword == "ordered") {
690       mlir::IntegerAttr attr;
691       if (checkAllowed(orderedClause))
692         return failure();
693       if (succeeded(parser.parseOptionalLParen())) {
694         auto type = parser.getBuilder().getI64Type();
695         if (parser.parseAttribute(attr, type) || parser.parseRParen())
696           return failure();
697       } else {
698         // Use 0 to represent no ordered parameter was specified
699         attr = parser.getBuilder().getI64IntegerAttr(0);
700       }
701       result.addAttribute("ordered_val", attr);
702     } else if (clauseKeyword == "order") {
703       if (checkAllowed(orderClause) ||
704           parseClauseAttr<ClauseOrderKindAttr>(parser, result, "order_val",
705                                                "order"))
706         return failure();
707     } else if (clauseKeyword == "memory_order") {
708       if (checkAllowed(memoryOrderClause) ||
709           parseClauseAttr<ClauseMemoryOrderKindAttr>(
710               parser, result, "memory_order", "memory order"))
711         return failure();
712     } else if (clauseKeyword == "hint") {
713       IntegerAttr hint;
714       if (checkAllowed(hintClause) ||
715           parseSynchronizationHint(parser, hint, false))
716         return failure();
717       result.addAttribute("hint", hint);
718     } else {
719       return parser.emitError(parser.getNameLoc())
720              << clauseKeyword << " is not a valid clause";
721     }
722   }
723 
724   // Add if parameter.
725   if (done[ifClause] && clauseSegments[pos[ifClause]] &&
726       failed(
727           parser.resolveOperand(ifCond.first, ifCond.second, result.operands)))
728     return failure();
729 
730   // Add num_threads parameter.
731   if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] &&
732       failed(parser.resolveOperand(numThreads.first, numThreads.second,
733                                    result.operands)))
734     return failure();
735 
736   // Add device parameter.
737   if (done[deviceClause] && clauseSegments[pos[deviceClause]] &&
738       failed(
739           parser.resolveOperand(device.first, device.second, result.operands)))
740     return failure();
741 
742   // Add thread_limit parameter.
743   if (done[threadLimitClause] && clauseSegments[pos[threadLimitClause]] &&
744       failed(parser.resolveOperand(threadLimit.first, threadLimit.second,
745                                    result.operands)))
746     return failure();
747 
748   // Add allocate parameters.
749   if (done[allocateClause] && clauseSegments[pos[allocateClause]] &&
750       failed(parser.resolveOperands(allocates, allocateTypes,
751                                     allocates[0].location, result.operands)))
752     return failure();
753 
754   // Add allocator parameters.
755   if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] &&
756       failed(parser.resolveOperands(allocators, allocatorTypes,
757                                     allocators[0].location, result.operands)))
758     return failure();
759 
760   // Add reduction parameters and symbols
761   if (done[reductionClause] && clauseSegments[pos[reductionClause]]) {
762     if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
763                                       parser.getNameLoc(), result.operands)))
764       return failure();
765 
766     SmallVector<Attribute> reductions(reductionSymbols.begin(),
767                                       reductionSymbols.end());
768     result.addAttribute("reductions",
769                         parser.getBuilder().getArrayAttr(reductions));
770   }
771 
772   // Add linear parameters
773   if (done[linearClause] && clauseSegments[pos[linearClause]]) {
774     auto linearStepType = parser.getBuilder().getI32Type();
775     SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
776     if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location,
777                                       result.operands)) ||
778         failed(parser.resolveOperands(linearSteps, linearStepTypes,
779                                       linearSteps[0].location,
780                                       result.operands)))
781       return failure();
782   }
783 
784   // Add schedule parameters
785   if (done[scheduleClause] && !schedule.empty()) {
786     schedule[0] = llvm::toUpper(schedule[0]);
787     if (Optional<ClauseScheduleKind> sched =
788             symbolizeClauseScheduleKind(schedule)) {
789       auto attr = ClauseScheduleKindAttr::get(parser.getContext(), *sched);
790       result.addAttribute("schedule_val", attr);
791     } else {
792       return parser.emitError(parser.getCurrentLocation(),
793                               "invalid schedule kind");
794     }
795     if (!modifiers.empty()) {
796       SMLoc loc = parser.getCurrentLocation();
797       if (Optional<ScheduleModifier> mod =
798               symbolizeScheduleModifier(modifiers[0])) {
799         result.addAttribute(
800             "schedule_modifier",
801             ScheduleModifierAttr::get(parser.getContext(), *mod));
802       } else {
803         return parser.emitError(loc, "invalid schedule modifier");
804       }
805       // Only SIMD attribute is allowed here!
806       if (modifiers.size() > 1) {
807         assert(symbolizeScheduleModifier(modifiers[1]) ==
808                ScheduleModifier::simd);
809         auto attr = UnitAttr::get(parser.getBuilder().getContext());
810         result.addAttribute("simd_modifier", attr);
811       }
812     }
813     if (scheduleChunkSize)
814       parser.resolveOperand(*scheduleChunkSize, scheduleChunkType,
815                             result.operands);
816   }
817 
818   segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end());
819 
820   return success();
821 }
822 
823 /// Parses a parallel operation.
824 ///
825 /// operation ::= `omp.parallel` clause-list
826 /// clause-list ::= clause | clause clause-list
827 /// clause ::= if | num-threads | allocate | proc-bind
828 ///
829 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
830   SmallVector<ClauseType> clauses = {ifClause, numThreadsClause, allocateClause,
831                                      procBindClause};
832 
833   SmallVector<int> segments;
834 
835   if (failed(parseClauses(parser, result, clauses, segments)))
836     return failure();
837 
838   result.addAttribute("operand_segment_sizes",
839                       parser.getBuilder().getI32VectorAttr(segments));
840 
841   Region *body = result.addRegion();
842   SmallVector<OpAsmParser::OperandType> regionArgs;
843   SmallVector<Type> regionArgTypes;
844   if (parser.parseRegion(*body, regionArgs, regionArgTypes))
845     return failure();
846   return success();
847 }
848 
849 /// Parses a target operation.
850 ///
851 /// operation ::= `omp.target` clause-list
852 /// clause-list ::= clause | clause clause-list
853 /// clause ::= if | device | thread_limit | nowait
854 ///
855 ParseResult TargetOp::parse(OpAsmParser &parser, OperationState &result) {
856   SmallVector<ClauseType> clauses = {ifClause, deviceClause, threadLimitClause,
857                                      nowaitClause};
858 
859   SmallVector<int> segments;
860 
861   if (failed(parseClauses(parser, result, clauses, segments)))
862     return failure();
863 
864   result.addAttribute(
865       TargetOp::AttrSizedOperandSegments::getOperandSegmentSizeAttr(),
866       parser.getBuilder().getI32VectorAttr(segments));
867 
868   Region *body = result.addRegion();
869   SmallVector<OpAsmParser::OperandType> regionArgs;
870   SmallVector<Type> regionArgTypes;
871   if (parser.parseRegion(*body, regionArgs, regionArgTypes))
872     return failure();
873   return success();
874 }
875 
876 //===----------------------------------------------------------------------===//
877 // Parser, printer and verifier for SectionsOp
878 //===----------------------------------------------------------------------===//
879 
880 /// Parses an OpenMP Sections operation
881 ///
882 /// sections ::= `omp.sections` clause-list
883 /// clause-list ::= clause clause-list | empty
884 /// clause ::= reduction | allocate | nowait
885 ParseResult SectionsOp::parse(OpAsmParser &parser, OperationState &result) {
886   SmallVector<ClauseType> clauses = {reductionClause, allocateClause,
887                                      nowaitClause};
888 
889   SmallVector<int> segments;
890 
891   if (failed(parseClauses(parser, result, clauses, segments)))
892     return failure();
893 
894   result.addAttribute("operand_segment_sizes",
895                       parser.getBuilder().getI32VectorAttr(segments));
896 
897   // Now parse the body.
898   Region *body = result.addRegion();
899   if (parser.parseRegion(*body))
900     return failure();
901   return success();
902 }
903 
904 void SectionsOp::print(OpAsmPrinter &p) {
905   p << " ";
906 
907   if (!reduction_vars().empty())
908     printReductionVarList(p, reductions(), reduction_vars());
909 
910   if (!allocate_vars().empty())
911     printAllocateAndAllocator(p, allocate_vars(), allocators_vars());
912 
913   if (nowait())
914     p << "nowait";
915 
916   p << ' ';
917   p.printRegion(region());
918 }
919 
920 LogicalResult SectionsOp::verify() {
921   if (allocate_vars().size() != allocators_vars().size())
922     return emitError(
923         "expected equal sizes for allocate and allocator variables");
924 
925   for (auto &inst : *region().begin()) {
926     if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
927       return emitOpError()
928              << "expected omp.section op or terminator op inside region";
929     }
930   }
931 
932   return verifyReductionVarList(*this, reductions(), reduction_vars());
933 }
934 
935 /// Parses an OpenMP Workshare Loop operation
936 ///
937 /// wsloop ::= `omp.wsloop` loop-control clause-list
938 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
939 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
940 /// steps := `step` `(`ssa-id-list`)`
941 /// clause-list ::= clause clause-list | empty
942 /// clause ::= linear | schedule | collapse | nowait | ordered | order
943 ///          | reduction
944 ParseResult WsLoopOp::parse(OpAsmParser &parser, OperationState &result) {
945   // Parse an opening `(` followed by induction variables followed by `)`
946   SmallVector<OpAsmParser::OperandType> ivs;
947   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
948                                      OpAsmParser::Delimiter::Paren))
949     return failure();
950 
951   int numIVs = static_cast<int>(ivs.size());
952   Type loopVarType;
953   if (parser.parseColonType(loopVarType))
954     return failure();
955 
956   // Parse loop bounds.
957   SmallVector<OpAsmParser::OperandType> lower;
958   if (parser.parseEqual() ||
959       parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
960       parser.resolveOperands(lower, loopVarType, result.operands))
961     return failure();
962 
963   SmallVector<OpAsmParser::OperandType> upper;
964   if (parser.parseKeyword("to") ||
965       parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
966       parser.resolveOperands(upper, loopVarType, result.operands))
967     return failure();
968 
969   if (succeeded(parser.parseOptionalKeyword("inclusive"))) {
970     auto attr = UnitAttr::get(parser.getBuilder().getContext());
971     result.addAttribute("inclusive", attr);
972   }
973 
974   // Parse step values.
975   SmallVector<OpAsmParser::OperandType> steps;
976   if (parser.parseKeyword("step") ||
977       parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
978       parser.resolveOperands(steps, loopVarType, result.operands))
979     return failure();
980 
981   SmallVector<ClauseType> clauses = {
982       linearClause,  reductionClause, collapseClause, orderClause,
983       orderedClause, nowaitClause,    scheduleClause};
984   SmallVector<int> segments{numIVs, numIVs, numIVs};
985   if (failed(parseClauses(parser, result, clauses, segments)))
986     return failure();
987 
988   result.addAttribute("operand_segment_sizes",
989                       parser.getBuilder().getI32VectorAttr(segments));
990 
991   // Now parse the body.
992   Region *body = result.addRegion();
993   SmallVector<Type> ivTypes(numIVs, loopVarType);
994   SmallVector<OpAsmParser::OperandType> blockArgs(ivs);
995   if (parser.parseRegion(*body, blockArgs, ivTypes))
996     return failure();
997   return success();
998 }
999 
1000 void WsLoopOp::print(OpAsmPrinter &p) {
1001   auto args = getRegion().front().getArguments();
1002   p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound()
1003     << ") to (" << upperBound() << ") ";
1004   if (inclusive()) {
1005     p << "inclusive ";
1006   }
1007   p << "step (" << step() << ") ";
1008 
1009   if (!linear_vars().empty())
1010     printLinearClause(p, linear_vars(), linear_step_vars());
1011 
1012   if (auto sched = schedule_val())
1013     printScheduleClause(p, sched.getValue(), schedule_modifier(),
1014                         simd_modifier(), schedule_chunk_var());
1015 
1016   if (auto collapse = collapse_val())
1017     p << "collapse(" << collapse << ") ";
1018 
1019   if (nowait())
1020     p << "nowait ";
1021 
1022   if (auto ordered = ordered_val())
1023     p << "ordered(" << ordered << ") ";
1024 
1025   if (auto order = order_val())
1026     p << "order(" << stringifyClauseOrderKind(*order) << ") ";
1027 
1028   if (!reduction_vars().empty())
1029     printReductionVarList(p, reductions(), reduction_vars());
1030 
1031   p << ' ';
1032   p.printRegion(region(), /*printEntryBlockArgs=*/false);
1033 }
1034 
1035 //===----------------------------------------------------------------------===//
1036 // ReductionOp
1037 //===----------------------------------------------------------------------===//
1038 
1039 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
1040                                               Region &region) {
1041   if (parser.parseOptionalKeyword("atomic"))
1042     return success();
1043   return parser.parseRegion(region);
1044 }
1045 
1046 static void printAtomicReductionRegion(OpAsmPrinter &printer,
1047                                        ReductionDeclareOp op, Region &region) {
1048   if (region.empty())
1049     return;
1050   printer << "atomic ";
1051   printer.printRegion(region);
1052 }
1053 
1054 LogicalResult ReductionDeclareOp::verify() {
1055   if (initializerRegion().empty())
1056     return emitOpError() << "expects non-empty initializer region";
1057   Block &initializerEntryBlock = initializerRegion().front();
1058   if (initializerEntryBlock.getNumArguments() != 1 ||
1059       initializerEntryBlock.getArgument(0).getType() != type()) {
1060     return emitOpError() << "expects initializer region with one argument "
1061                             "of the reduction type";
1062   }
1063 
1064   for (YieldOp yieldOp : initializerRegion().getOps<YieldOp>()) {
1065     if (yieldOp.results().size() != 1 ||
1066         yieldOp.results().getTypes()[0] != type())
1067       return emitOpError() << "expects initializer region to yield a value "
1068                               "of the reduction type";
1069   }
1070 
1071   if (reductionRegion().empty())
1072     return emitOpError() << "expects non-empty reduction region";
1073   Block &reductionEntryBlock = reductionRegion().front();
1074   if (reductionEntryBlock.getNumArguments() != 2 ||
1075       reductionEntryBlock.getArgumentTypes()[0] !=
1076           reductionEntryBlock.getArgumentTypes()[1] ||
1077       reductionEntryBlock.getArgumentTypes()[0] != type())
1078     return emitOpError() << "expects reduction region with two arguments of "
1079                             "the reduction type";
1080   for (YieldOp yieldOp : reductionRegion().getOps<YieldOp>()) {
1081     if (yieldOp.results().size() != 1 ||
1082         yieldOp.results().getTypes()[0] != type())
1083       return emitOpError() << "expects reduction region to yield a value "
1084                               "of the reduction type";
1085   }
1086 
1087   if (atomicReductionRegion().empty())
1088     return success();
1089 
1090   Block &atomicReductionEntryBlock = atomicReductionRegion().front();
1091   if (atomicReductionEntryBlock.getNumArguments() != 2 ||
1092       atomicReductionEntryBlock.getArgumentTypes()[0] !=
1093           atomicReductionEntryBlock.getArgumentTypes()[1])
1094     return emitOpError() << "expects atomic reduction region with two "
1095                             "arguments of the same type";
1096   auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
1097                      .dyn_cast<PointerLikeType>();
1098   if (!ptrType || ptrType.getElementType() != type())
1099     return emitOpError() << "expects atomic reduction region arguments to "
1100                             "be accumulators containing the reduction type";
1101   return success();
1102 }
1103 
1104 LogicalResult ReductionOp::verify() {
1105   // TODO: generalize this to an op interface when there is more than one op
1106   // that supports reductions.
1107   auto container = (*this)->getParentOfType<WsLoopOp>();
1108   for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
1109     if (container.reduction_vars()[i] == accumulator())
1110       return success();
1111 
1112   return emitOpError() << "the accumulator is not used by the parent";
1113 }
1114 
1115 //===----------------------------------------------------------------------===//
1116 // WsLoopOp
1117 //===----------------------------------------------------------------------===//
1118 
1119 void WsLoopOp::build(OpBuilder &builder, OperationState &state,
1120                      ValueRange lowerBound, ValueRange upperBound,
1121                      ValueRange step, ArrayRef<NamedAttribute> attributes) {
1122   build(builder, state, TypeRange(), lowerBound, upperBound, step,
1123         /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
1124         /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr,
1125         /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr,
1126         /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr,
1127         /*inclusive=*/nullptr, /*buildBody=*/false);
1128   state.addAttributes(attributes);
1129 }
1130 
1131 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes,
1132                      ValueRange operands, ArrayRef<NamedAttribute> attributes) {
1133   state.addOperands(operands);
1134   state.addAttributes(attributes);
1135   (void)state.addRegion();
1136   assert(resultTypes.empty() && "mismatched number of return types");
1137   state.addTypes(resultTypes);
1138 }
1139 
1140 void WsLoopOp::build(OpBuilder &builder, OperationState &result,
1141                      TypeRange typeRange, ValueRange lowerBounds,
1142                      ValueRange upperBounds, ValueRange steps,
1143                      ValueRange linearVars, ValueRange linearStepVars,
1144                      ValueRange reductionVars, StringAttr scheduleVal,
1145                      Value scheduleChunkVar, IntegerAttr collapseVal,
1146                      UnitAttr nowait, IntegerAttr orderedVal,
1147                      StringAttr orderVal, UnitAttr inclusive, bool buildBody) {
1148   result.addOperands(lowerBounds);
1149   result.addOperands(upperBounds);
1150   result.addOperands(steps);
1151   result.addOperands(linearVars);
1152   result.addOperands(linearStepVars);
1153   if (scheduleChunkVar)
1154     result.addOperands(scheduleChunkVar);
1155 
1156   if (scheduleVal)
1157     result.addAttribute("schedule_val", scheduleVal);
1158   if (collapseVal)
1159     result.addAttribute("collapse_val", collapseVal);
1160   if (nowait)
1161     result.addAttribute("nowait", nowait);
1162   if (orderedVal)
1163     result.addAttribute("ordered_val", orderedVal);
1164   if (orderVal)
1165     result.addAttribute("order", orderVal);
1166   if (inclusive)
1167     result.addAttribute("inclusive", inclusive);
1168   result.addAttribute(
1169       WsLoopOp::getOperandSegmentSizeAttr(),
1170       builder.getI32VectorAttr(
1171           {static_cast<int32_t>(lowerBounds.size()),
1172            static_cast<int32_t>(upperBounds.size()),
1173            static_cast<int32_t>(steps.size()),
1174            static_cast<int32_t>(linearVars.size()),
1175            static_cast<int32_t>(linearStepVars.size()),
1176            static_cast<int32_t>(reductionVars.size()),
1177            static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)}));
1178 
1179   Region *bodyRegion = result.addRegion();
1180   if (buildBody) {
1181     OpBuilder::InsertionGuard guard(builder);
1182     unsigned numIVs = steps.size();
1183     SmallVector<Type, 8> argTypes(numIVs, steps.getType().front());
1184     SmallVector<Location, 8> argLocs(numIVs, result.location);
1185     builder.createBlock(bodyRegion, {}, argTypes, argLocs);
1186   }
1187 }
1188 
1189 LogicalResult WsLoopOp::verify() {
1190   return verifyReductionVarList(*this, reductions(), reduction_vars());
1191 }
1192 
1193 //===----------------------------------------------------------------------===//
1194 // Verifier for critical construct (2.17.1)
1195 //===----------------------------------------------------------------------===//
1196 
1197 LogicalResult CriticalDeclareOp::verify() {
1198   return verifySynchronizationHint(*this, hint());
1199 }
1200 
1201 LogicalResult CriticalOp::verify() {
1202   if (nameAttr()) {
1203     SymbolRefAttr symbolRef = nameAttr();
1204     auto decl = SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(
1205         *this, symbolRef);
1206     if (!decl) {
1207       return emitOpError() << "expected symbol reference " << symbolRef
1208                            << " to point to a critical declaration";
1209     }
1210   }
1211 
1212   return success();
1213 }
1214 
1215 //===----------------------------------------------------------------------===//
1216 // Verifier for ordered construct
1217 //===----------------------------------------------------------------------===//
1218 
1219 LogicalResult OrderedOp::verify() {
1220   auto container = (*this)->getParentOfType<WsLoopOp>();
1221   if (!container || !container.ordered_valAttr() ||
1222       container.ordered_valAttr().getInt() == 0)
1223     return emitOpError() << "ordered depend directive must be closely "
1224                          << "nested inside a worksharing-loop with ordered "
1225                          << "clause with parameter present";
1226 
1227   if (container.ordered_valAttr().getInt() !=
1228       (int64_t)num_loops_val().getValue())
1229     return emitOpError() << "number of variables in depend clause does not "
1230                          << "match number of iteration variables in the "
1231                          << "doacross loop";
1232 
1233   return success();
1234 }
1235 
1236 LogicalResult OrderedRegionOp::verify() {
1237   // TODO: The code generation for ordered simd directive is not supported yet.
1238   if (simd())
1239     return failure();
1240 
1241   if (auto container = (*this)->getParentOfType<WsLoopOp>()) {
1242     if (!container.ordered_valAttr() ||
1243         container.ordered_valAttr().getInt() != 0)
1244       return emitOpError() << "ordered region must be closely nested inside "
1245                            << "a worksharing-loop region with an ordered "
1246                            << "clause without parameter present";
1247   }
1248 
1249   return success();
1250 }
1251 
1252 //===----------------------------------------------------------------------===//
1253 // AtomicReadOp
1254 //===----------------------------------------------------------------------===//
1255 
1256 /// Parser for AtomicReadOp
1257 ///
1258 /// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type
1259 /// address ::= operand `:` type
1260 ParseResult AtomicReadOp::parse(OpAsmParser &parser, OperationState &result) {
1261   OpAsmParser::OperandType x, v;
1262   Type addressType;
1263   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1264   SmallVector<int> segments;
1265 
1266   if (parser.parseOperand(v) || parser.parseEqual() || parser.parseOperand(x) ||
1267       parseClauses(parser, result, clauses, segments) ||
1268       parser.parseColonType(addressType) ||
1269       parser.resolveOperand(x, addressType, result.operands) ||
1270       parser.resolveOperand(v, addressType, result.operands))
1271     return failure();
1272 
1273   return success();
1274 }
1275 
1276 void AtomicReadOp::print(OpAsmPrinter &p) {
1277   p << " " << v() << " = " << x() << " ";
1278   if (auto mo = memory_order())
1279     p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
1280   if (hintAttr())
1281     printSynchronizationHint(p << " ", *this, hintAttr());
1282   p << ": " << x().getType();
1283 }
1284 
1285 /// Verifier for AtomicReadOp
1286 LogicalResult AtomicReadOp::verify() {
1287   if (auto mo = memory_order()) {
1288     if (*mo == ClauseMemoryOrderKind::acq_rel ||
1289         *mo == ClauseMemoryOrderKind::release) {
1290       return emitError(
1291           "memory-order must not be acq_rel or release for atomic reads");
1292     }
1293   }
1294   if (x() == v())
1295     return emitError(
1296         "read and write must not be to the same location for atomic reads");
1297   return verifySynchronizationHint(*this, hint());
1298 }
1299 
1300 //===----------------------------------------------------------------------===//
1301 // AtomicWriteOp
1302 //===----------------------------------------------------------------------===//
1303 
1304 /// Parser for AtomicWriteOp
1305 ///
1306 /// operation ::= `omp.atomic.write` atomic-clause-list operands
1307 /// operands ::= address `,` value
1308 /// address ::= operand `:` type
1309 /// value ::= operand `:` type
1310 ParseResult AtomicWriteOp::parse(OpAsmParser &parser, OperationState &result) {
1311   OpAsmParser::OperandType address, value;
1312   Type addrType, valueType;
1313   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1314   SmallVector<int> segments;
1315 
1316   if (parser.parseOperand(address) || parser.parseEqual() ||
1317       parser.parseOperand(value) ||
1318       parseClauses(parser, result, clauses, segments) ||
1319       parser.parseColonType(addrType) || parser.parseComma() ||
1320       parser.parseType(valueType) ||
1321       parser.resolveOperand(address, addrType, result.operands) ||
1322       parser.resolveOperand(value, valueType, result.operands))
1323     return failure();
1324   return success();
1325 }
1326 
1327 void AtomicWriteOp::print(OpAsmPrinter &p) {
1328   p << " " << address() << " = " << value() << " ";
1329   if (auto mo = memory_order())
1330     p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
1331   if (hintAttr())
1332     printSynchronizationHint(p, *this, hintAttr());
1333   p << ": " << address().getType() << ", " << value().getType();
1334 }
1335 
1336 /// Verifier for AtomicWriteOp
1337 LogicalResult AtomicWriteOp::verify() {
1338   if (auto mo = memory_order()) {
1339     if (*mo == ClauseMemoryOrderKind::acq_rel ||
1340         *mo == ClauseMemoryOrderKind::acquire) {
1341       return emitError(
1342           "memory-order must not be acq_rel or acquire for atomic writes");
1343     }
1344   }
1345   return verifySynchronizationHint(*this, hint());
1346 }
1347 
1348 //===----------------------------------------------------------------------===//
1349 // AtomicUpdateOp
1350 //===----------------------------------------------------------------------===//
1351 
1352 /// Parser for AtomicUpdateOp
1353 ///
1354 /// operation ::= `omp.atomic.update` atomic-clause-list ssa-id-and-type region
1355 ParseResult AtomicUpdateOp::parse(OpAsmParser &parser, OperationState &result) {
1356   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1357   SmallVector<int> segments;
1358   OpAsmParser::OperandType x, expr;
1359   Type xType;
1360 
1361   if (parseClauses(parser, result, clauses, segments) ||
1362       parser.parseOperand(x) || parser.parseColon() ||
1363       parser.parseType(xType) ||
1364       parser.resolveOperand(x, xType, result.operands) ||
1365       parser.parseRegion(*result.addRegion()))
1366     return failure();
1367   return success();
1368 }
1369 
1370 void AtomicUpdateOp::print(OpAsmPrinter &p) {
1371   p << " ";
1372   if (auto mo = memory_order())
1373     p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
1374   if (hintAttr())
1375     printSynchronizationHint(p, *this, hintAttr());
1376   p << x() << " : " << x().getType();
1377   p.printRegion(region());
1378 }
1379 
1380 /// Verifier for AtomicUpdateOp
1381 LogicalResult AtomicUpdateOp::verify() {
1382   if (auto mo = memory_order()) {
1383     if (*mo == ClauseMemoryOrderKind::acq_rel ||
1384         *mo == ClauseMemoryOrderKind::acquire) {
1385       return emitError(
1386           "memory-order must not be acq_rel or acquire for atomic updates");
1387     }
1388   }
1389 
1390   if (region().getNumArguments() != 1)
1391     return emitError("the region must accept exactly one argument");
1392 
1393   if (x().getType().cast<PointerLikeType>().getElementType() !=
1394       region().getArgument(0).getType()) {
1395     return emitError("the type of the operand must be a pointer type whose "
1396                      "element type is the same as that of the region argument");
1397   }
1398 
1399   YieldOp yieldOp = *region().getOps<YieldOp>().begin();
1400 
1401   if (yieldOp.results().size() != 1)
1402     return emitError("only updated value must be returned");
1403   if (yieldOp.results().front().getType() != region().getArgument(0).getType())
1404     return emitError("input and yielded value must have the same type");
1405   return success();
1406 }
1407 
1408 //===----------------------------------------------------------------------===//
1409 // AtomicCaptureOp
1410 //===----------------------------------------------------------------------===//
1411 
1412 ParseResult AtomicCaptureOp::parse(OpAsmParser &parser,
1413                                    OperationState &result) {
1414   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1415   SmallVector<int> segments;
1416   if (parseClauses(parser, result, clauses, segments) ||
1417       parser.parseRegion(*result.addRegion()))
1418     return failure();
1419   return success();
1420 }
1421 
1422 void AtomicCaptureOp::print(OpAsmPrinter &p) {
1423   if (memory_order())
1424     p << "memory_order(" << memory_order() << ") ";
1425   if (hintAttr())
1426     printSynchronizationHint(p, *this, hintAttr());
1427   p.printRegion(region());
1428 }
1429 
1430 /// Verifier for AtomicCaptureOp
1431 LogicalResult AtomicCaptureOp::verify() {
1432   Block::OpListType &ops = region().front().getOperations();
1433   if (ops.size() != 3)
1434     return emitError()
1435            << "expected three operations in omp.atomic.capture region (one "
1436               "terminator, and two atomic ops)";
1437   auto &firstOp = ops.front();
1438   auto &secondOp = *ops.getNextNode(firstOp);
1439   auto firstReadStmt = dyn_cast<AtomicReadOp>(firstOp);
1440   auto firstUpdateStmt = dyn_cast<AtomicUpdateOp>(firstOp);
1441   auto secondReadStmt = dyn_cast<AtomicReadOp>(secondOp);
1442   auto secondUpdateStmt = dyn_cast<AtomicUpdateOp>(secondOp);
1443   auto secondWriteStmt = dyn_cast<AtomicWriteOp>(secondOp);
1444 
1445   if (!((firstUpdateStmt && secondReadStmt) ||
1446         (firstReadStmt && secondUpdateStmt) ||
1447         (firstReadStmt && secondWriteStmt)))
1448     return ops.front().emitError()
1449            << "invalid sequence of operations in the capture region";
1450   if (firstUpdateStmt && secondReadStmt &&
1451       firstUpdateStmt.x() != secondReadStmt.x())
1452     return firstUpdateStmt.emitError()
1453            << "updated variable in omp.atomic.update must be captured in "
1454               "second operation";
1455   if (firstReadStmt && secondUpdateStmt &&
1456       firstReadStmt.x() != secondUpdateStmt.x())
1457     return firstReadStmt.emitError()
1458            << "captured variable in omp.atomic.read must be updated in second "
1459               "operation";
1460   if (firstReadStmt && secondWriteStmt &&
1461       firstReadStmt.x() != secondWriteStmt.address())
1462     return firstReadStmt.emitError()
1463            << "captured variable in omp.atomic.read must be updated in "
1464               "second operation";
1465   return success();
1466 }
1467 
1468 #define GET_ATTRDEF_CLASSES
1469 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
1470 
1471 #define GET_OP_CLASSES
1472 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
1473