1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/OperationSupport.h"
19 
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/SmallString.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/ADT/StringSwitch.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include <cstddef>
27 
28 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
29 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
30 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
31 
32 using namespace mlir;
33 using namespace mlir::omp;
34 
35 namespace {
36 /// Model for pointer-like types that already provide a `getElementType` method.
37 template <typename T>
38 struct PointerLikeModel
39     : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
40   Type getElementType(Type pointer) const {
41     return pointer.cast<T>().getElementType();
42   }
43 };
44 } // namespace
45 
46 void OpenMPDialect::initialize() {
47   addOperations<
48 #define GET_OP_LIST
49 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
50       >();
51   addAttributes<
52 #define GET_ATTRDEF_LIST
53 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
54       >();
55 
56   LLVM::LLVMPointerType::attachInterface<
57       PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
58   MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // ParallelOp
63 //===----------------------------------------------------------------------===//
64 
65 void ParallelOp::build(OpBuilder &builder, OperationState &state,
66                        ArrayRef<NamedAttribute> attributes) {
67   ParallelOp::build(
68       builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
69       /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
70       /*proc_bind_val=*/nullptr);
71   state.addAttributes(attributes);
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // Parser and printer for Allocate Clause
76 //===----------------------------------------------------------------------===//
77 
78 /// Parse an allocate clause with allocators and a list of operands with types.
79 ///
80 /// allocate-operand-list :: = allocate-operand |
81 ///                            allocator-operand `,` allocate-operand-list
82 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
83 /// ssa-id-and-type ::= ssa-id `:` type
84 static ParseResult parseAllocateAndAllocator(
85     OpAsmParser &parser,
86     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
87     SmallVectorImpl<Type> &typesAllocate,
88     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
89     SmallVectorImpl<Type> &typesAllocator) {
90 
91   return parser.parseCommaSeparatedList([&]() -> ParseResult {
92     OpAsmParser::OperandType operand;
93     Type type;
94     if (parser.parseOperand(operand) || parser.parseColonType(type))
95       return failure();
96     operandsAllocator.push_back(operand);
97     typesAllocator.push_back(type);
98     if (parser.parseArrow())
99       return failure();
100     if (parser.parseOperand(operand) || parser.parseColonType(type))
101       return failure();
102 
103     operandsAllocate.push_back(operand);
104     typesAllocate.push_back(type);
105     return success();
106   });
107 }
108 
109 /// Print allocate clause
110 static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op,
111                                       OperandRange varsAllocate,
112                                       TypeRange typesAllocate,
113                                       OperandRange varsAllocator,
114                                       TypeRange typesAllocator) {
115   for (unsigned i = 0; i < varsAllocate.size(); ++i) {
116     std::string separator = i == varsAllocate.size() - 1 ? "" : ", ";
117     p << varsAllocator[i] << " : " << typesAllocator[i] << " -> ";
118     p << varsAllocate[i] << " : " << typesAllocate[i] << separator;
119   }
120 }
121 
122 ParseResult parseProcBindKind(OpAsmParser &parser,
123                               omp::ClauseProcBindKindAttr &procBindAttr) {
124   StringRef procBindStr;
125   if (parser.parseKeyword(&procBindStr))
126     return failure();
127   if (auto procBindVal = symbolizeClauseProcBindKind(procBindStr)) {
128     procBindAttr =
129         ClauseProcBindKindAttr::get(parser.getContext(), *procBindVal);
130     return success();
131   }
132   return failure();
133 }
134 
135 void printProcBindKind(OpAsmPrinter &p, Operation *op,
136                        omp::ClauseProcBindKindAttr procBindAttr) {
137   p << stringifyClauseProcBindKind(procBindAttr.getValue());
138 }
139 
140 LogicalResult ParallelOp::verify() {
141   if (allocate_vars().size() != allocators_vars().size())
142     return emitError(
143         "expected equal sizes for allocate and allocator variables");
144   return success();
145 }
146 
147 //===----------------------------------------------------------------------===//
148 // Parser and printer for Linear Clause
149 //===----------------------------------------------------------------------===//
150 
151 /// linear ::= `linear` `(` linear-list `)`
152 /// linear-list := linear-val | linear-val linear-list
153 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
154 static ParseResult
155 parseLinearClause(OpAsmParser &parser,
156                   SmallVectorImpl<OpAsmParser::OperandType> &vars,
157                   SmallVectorImpl<Type> &types,
158                   SmallVectorImpl<OpAsmParser::OperandType> &stepVars) {
159   if (parser.parseLParen())
160     return failure();
161 
162   do {
163     OpAsmParser::OperandType var;
164     Type type;
165     OpAsmParser::OperandType stepVar;
166     if (parser.parseOperand(var) || parser.parseEqual() ||
167         parser.parseOperand(stepVar) || parser.parseColonType(type))
168       return failure();
169 
170     vars.push_back(var);
171     types.push_back(type);
172     stepVars.push_back(stepVar);
173   } while (succeeded(parser.parseOptionalComma()));
174 
175   if (parser.parseRParen())
176     return failure();
177 
178   return success();
179 }
180 
181 /// Print Linear Clause
182 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars,
183                               OperandRange linearStepVars) {
184   size_t linearVarsSize = linearVars.size();
185   p << "linear(";
186   for (unsigned i = 0; i < linearVarsSize; ++i) {
187     std::string separator = i == linearVarsSize - 1 ? ") " : ", ";
188     p << linearVars[i];
189     if (linearStepVars.size() > i)
190       p << " = " << linearStepVars[i];
191     p << " : " << linearVars[i].getType() << separator;
192   }
193 }
194 
195 //===----------------------------------------------------------------------===//
196 // Parser and printer for Schedule Clause
197 //===----------------------------------------------------------------------===//
198 
199 static ParseResult
200 verifyScheduleModifiers(OpAsmParser &parser,
201                         SmallVectorImpl<SmallString<12>> &modifiers) {
202   if (modifiers.size() > 2)
203     return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
204   for (const auto &mod : modifiers) {
205     // Translate the string. If it has no value, then it was not a valid
206     // modifier!
207     auto symbol = symbolizeScheduleModifier(mod);
208     if (!symbol.hasValue())
209       return parser.emitError(parser.getNameLoc())
210              << " unknown modifier type: " << mod;
211   }
212 
213   // If we have one modifier that is "simd", then stick a "none" modiifer in
214   // index 0.
215   if (modifiers.size() == 1) {
216     if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
217       modifiers.push_back(modifiers[0]);
218       modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
219     }
220   } else if (modifiers.size() == 2) {
221     // If there are two modifier:
222     // First modifier should not be simd, second one should be simd
223     if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
224         symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
225       return parser.emitError(parser.getNameLoc())
226              << " incorrect modifier order";
227   }
228   return success();
229 }
230 
231 /// schedule ::= `schedule` `(` sched-list `)`
232 /// sched-list ::= sched-val | sched-val sched-list |
233 ///                sched-val `,` sched-modifier
234 /// sched-val ::= sched-with-chunk | sched-wo-chunk
235 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
236 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
237 /// sched-wo-chunk ::=  `auto` | `runtime`
238 /// sched-modifier ::=  sched-mod-val | sched-mod-val `,` sched-mod-val
239 /// sched-mod-val ::=  `monotonic` | `nonmonotonic` | `simd` | `none`
240 static ParseResult
241 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
242                     SmallVectorImpl<SmallString<12>> &modifiers,
243                     Optional<OpAsmParser::OperandType> &chunkSize,
244                     Type &chunkType) {
245   if (parser.parseLParen())
246     return failure();
247 
248   StringRef keyword;
249   if (parser.parseKeyword(&keyword))
250     return failure();
251 
252   schedule = keyword;
253   if (keyword == "static" || keyword == "dynamic" || keyword == "guided") {
254     if (succeeded(parser.parseOptionalEqual())) {
255       chunkSize = OpAsmParser::OperandType{};
256       if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
257         return failure();
258     } else {
259       chunkSize = llvm::NoneType::None;
260     }
261   } else if (keyword == "auto" || keyword == "runtime") {
262     chunkSize = llvm::NoneType::None;
263   } else {
264     return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
265   }
266 
267   // If there is a comma, we have one or more modifiers..
268   while (succeeded(parser.parseOptionalComma())) {
269     StringRef mod;
270     if (parser.parseKeyword(&mod))
271       return failure();
272     modifiers.push_back(mod);
273   }
274 
275   if (parser.parseRParen())
276     return failure();
277 
278   if (verifyScheduleModifiers(parser, modifiers))
279     return failure();
280 
281   return success();
282 }
283 
284 /// Print schedule clause
285 static void printScheduleClause(OpAsmPrinter &p, ClauseScheduleKind sched,
286                                 Optional<ScheduleModifier> modifier, bool simd,
287                                 Value scheduleChunkVar) {
288   p << "schedule(" << stringifyClauseScheduleKind(sched).lower();
289   if (scheduleChunkVar)
290     p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType();
291   if (modifier)
292     p << ", " << stringifyScheduleModifier(*modifier);
293   if (simd)
294     p << ", simd";
295   p << ") ";
296 }
297 
298 //===----------------------------------------------------------------------===//
299 // Parser, printer and verifier for ReductionVarList
300 //===----------------------------------------------------------------------===//
301 
302 /// reduction-entry-list ::= reduction-entry
303 ///                        | reduction-entry-list `,` reduction-entry
304 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
305 static ParseResult parseReductionVarList(
306     OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &operands,
307     SmallVectorImpl<Type> &types, ArrayAttr &redcuctionSymbols) {
308   SmallVector<SymbolRefAttr> reductionVec;
309   do {
310     if (parser.parseAttribute(reductionVec.emplace_back()) ||
311         parser.parseArrow() || parser.parseOperand(operands.emplace_back()) ||
312         parser.parseColonType(types.emplace_back()))
313       return failure();
314   } while (succeeded(parser.parseOptionalComma()));
315   SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
316   redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions);
317   return success();
318 }
319 
320 /// Print Reduction clause
321 static void printReductionVarList(OpAsmPrinter &p, Operation *op,
322                                   OperandRange reductionVars,
323                                   TypeRange reductionTypes,
324                                   Optional<ArrayAttr> reductions) {
325   for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
326     if (i != 0)
327       p << ", ";
328     p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
329       << reductionVars[i].getType();
330   }
331 }
332 
333 /// Verifies Reduction Clause
334 static LogicalResult verifyReductionVarList(Operation *op,
335                                             Optional<ArrayAttr> reductions,
336                                             OperandRange reductionVars) {
337   if (!reductionVars.empty()) {
338     if (!reductions || reductions->size() != reductionVars.size())
339       return op->emitOpError()
340              << "expected as many reduction symbol references "
341                 "as reduction variables";
342   } else {
343     if (reductions)
344       return op->emitOpError() << "unexpected reduction symbol references";
345     return success();
346   }
347 
348   DenseSet<Value> accumulators;
349   for (auto args : llvm::zip(reductionVars, *reductions)) {
350     Value accum = std::get<0>(args);
351 
352     if (!accumulators.insert(accum).second)
353       return op->emitOpError() << "accumulator variable used more than once";
354 
355     Type varType = accum.getType().cast<PointerLikeType>();
356     auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
357     auto decl =
358         SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
359     if (!decl)
360       return op->emitOpError() << "expected symbol reference " << symbolRef
361                                << " to point to a reduction declaration";
362 
363     if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
364       return op->emitOpError()
365              << "expected accumulator (" << varType
366              << ") to be the same type as reduction declaration ("
367              << decl.getAccumulatorType() << ")";
368   }
369 
370   return success();
371 }
372 
373 //===----------------------------------------------------------------------===//
374 // Parser, printer and verifier for Synchronization Hint (2.17.12)
375 //===----------------------------------------------------------------------===//
376 
377 /// Parses a Synchronization Hint clause. The value of hint is an integer
378 /// which is a combination of different hints from `omp_sync_hint_t`.
379 ///
380 /// hint-clause = `hint` `(` hint-value `)`
381 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
382                                             IntegerAttr &hintAttr,
383                                             bool parseKeyword = true) {
384   if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) {
385     hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
386     return success();
387   }
388 
389   if (failed(parser.parseLParen()))
390     return failure();
391   StringRef hintKeyword;
392   int64_t hint = 0;
393   do {
394     if (failed(parser.parseKeyword(&hintKeyword)))
395       return failure();
396     if (hintKeyword == "uncontended")
397       hint |= 1;
398     else if (hintKeyword == "contended")
399       hint |= 2;
400     else if (hintKeyword == "nonspeculative")
401       hint |= 4;
402     else if (hintKeyword == "speculative")
403       hint |= 8;
404     else
405       return parser.emitError(parser.getCurrentLocation())
406              << hintKeyword << " is not a valid hint";
407   } while (succeeded(parser.parseOptionalComma()));
408   if (failed(parser.parseRParen()))
409     return failure();
410   hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
411   return success();
412 }
413 
414 /// Prints a Synchronization Hint clause
415 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
416                                      IntegerAttr hintAttr) {
417   int64_t hint = hintAttr.getInt();
418 
419   if (hint == 0)
420     return;
421 
422   // Helper function to get n-th bit from the right end of `value`
423   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
424 
425   bool uncontended = bitn(hint, 0);
426   bool contended = bitn(hint, 1);
427   bool nonspeculative = bitn(hint, 2);
428   bool speculative = bitn(hint, 3);
429 
430   SmallVector<StringRef> hints;
431   if (uncontended)
432     hints.push_back("uncontended");
433   if (contended)
434     hints.push_back("contended");
435   if (nonspeculative)
436     hints.push_back("nonspeculative");
437   if (speculative)
438     hints.push_back("speculative");
439 
440   p << "hint(";
441   llvm::interleaveComma(hints, p);
442   p << ") ";
443 }
444 
445 /// Verifies a synchronization hint clause
446 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
447 
448   // Helper function to get n-th bit from the right end of `value`
449   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
450 
451   bool uncontended = bitn(hint, 0);
452   bool contended = bitn(hint, 1);
453   bool nonspeculative = bitn(hint, 2);
454   bool speculative = bitn(hint, 3);
455 
456   if (uncontended && contended)
457     return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
458                                 "omp_sync_hint_contended cannot be combined";
459   if (nonspeculative && speculative)
460     return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
461                                 "omp_sync_hint_speculative cannot be combined.";
462   return success();
463 }
464 
/// Identifies each clause kind that parseClauses can parse. parseClauses uses
/// the enumerators as indices into its per-clause bookkeeping (the `done` bit
/// vector and the `pos` table). `COUNT` is the number of clause kinds and
/// sizes those tables, so it must remain the last enumerator.
enum ClauseType {
  ifClause,
  numThreadsClause,
  deviceClause,
  threadLimitClause,
  allocateClause,
  procBindClause,
  reductionClause,
  nowaitClause,
  linearClause,
  scheduleClause,
  collapseClause,
  orderClause,
  orderedClause,
  memoryOrderClause,
  hintClause,
  // Number of clause kinds; keep last.
  COUNT
};
483 
484 //===----------------------------------------------------------------------===//
485 // Parser for Clause List
486 //===----------------------------------------------------------------------===//
487 
488 /// Parse a clause attribute `(` $value `)`.
489 template <typename ClauseAttr>
490 static ParseResult parseClauseAttr(AsmParser &parser, OperationState &state,
491                                    StringRef attrName, StringRef name) {
492   using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
493   StringRef enumStr;
494   SMLoc loc = parser.getCurrentLocation();
495   if (parser.parseLParen() || parser.parseKeyword(&enumStr) ||
496       parser.parseRParen())
497     return failure();
498   if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
499     auto attr = ClauseAttr::get(parser.getContext(), *enumValue);
500     state.addAttribute(attrName, attr);
501     return success();
502   }
503   return parser.emitError(loc, "invalid ") << name << " kind";
504 }
505 
506 /// Parse a list of clauses. The clauses can appear in any order, but their
507 /// operand segment indices are in the same order that they are passed in the
508 /// `clauses` list. The operand segments are added over the prevSegments
509 
510 /// clause-list ::= clause clause-list | empty
511 /// clause ::= if | num-threads | allocate | proc-bind | reduction | nowait
512 ///          | linear | schedule | collapse | order | ordered | inclusive
513 /// if ::= `if` `(` ssa-id-and-type `)`
514 /// num-threads ::= `num_threads` `(` ssa-id-and-type `)`
515 /// allocate ::= `allocate` `(` allocate-operand-list `)`
516 /// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
517 /// reduction ::= `reduction` `(` reduction-entry-list `)`
518 /// nowait ::= `nowait`
519 /// linear ::= `linear` `(` linear-list `)`
520 /// schedule ::= `schedule` `(` sched-list `)`
521 /// collapse ::= `collapse` `(` ssa-id-and-type `)`
522 /// order ::= `order` `(` `concurrent` `)`
523 /// ordered ::= `ordered` `(` ssa-id-and-type `)`
524 /// inclusive ::= `inclusive`
525 ///
526 /// Note that each clause can only appear once in the clase-list.
527 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
528                                 SmallVectorImpl<ClauseType> &clauses,
529                                 SmallVectorImpl<int> &segments) {
530 
531   // Check done[clause] to see if it has been parsed already
532   BitVector done(ClauseType::COUNT, false);
533 
534   // See pos[clause] to get position of clause in operand segments
535   SmallVector<int> pos(ClauseType::COUNT, -1);
536 
537   // Stores the last parsed clause keyword
538   StringRef clauseKeyword;
539   StringRef opName = result.name.getStringRef();
540 
541   // Containers for storing operands, types and attributes for various clauses
542   std::pair<OpAsmParser::OperandType, Type> ifCond;
543   std::pair<OpAsmParser::OperandType, Type> numThreads;
544   std::pair<OpAsmParser::OperandType, Type> device;
545   std::pair<OpAsmParser::OperandType, Type> threadLimit;
546 
547   SmallVector<OpAsmParser::OperandType> allocates, allocators;
548   SmallVector<Type> allocateTypes, allocatorTypes;
549 
550   ArrayAttr reductions;
551   SmallVector<OpAsmParser::OperandType> reductionVars;
552   SmallVector<Type> reductionVarTypes;
553 
554   SmallVector<OpAsmParser::OperandType> linears;
555   SmallVector<Type> linearTypes;
556   SmallVector<OpAsmParser::OperandType> linearSteps;
557 
558   SmallString<8> schedule;
559   SmallVector<SmallString<12>> modifiers;
560   Optional<OpAsmParser::OperandType> scheduleChunkSize;
561   Type scheduleChunkType;
562 
563   // Compute the position of clauses in operand segments
564   int currPos = 0;
565   for (ClauseType clause : clauses) {
566 
567     // Skip the following clauses - they do not take any position in operand
568     // segments
569     if (clause == procBindClause || clause == nowaitClause ||
570         clause == collapseClause || clause == orderClause ||
571         clause == orderedClause)
572       continue;
573 
574     pos[clause] = currPos++;
575 
576     // For the following clauses, two positions are reserved in the operand
577     // segments
578     if (clause == allocateClause || clause == linearClause)
579       currPos++;
580   }
581 
582   SmallVector<int> clauseSegments(currPos);
583 
584   // Helper function to check if a clause is allowed/repeated or not
585   auto checkAllowed = [&](ClauseType clause) -> ParseResult {
586     if (!llvm::is_contained(clauses, clause))
587       return parser.emitError(parser.getCurrentLocation())
588              << clauseKeyword << " is not a valid clause for the " << opName
589              << " operation";
590     if (done[clause])
591       return parser.emitError(parser.getCurrentLocation())
592              << "at most one " << clauseKeyword << " clause can appear on the "
593              << opName << " operation";
594     done[clause] = true;
595     return success();
596   };
597 
598   while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) {
599     if (clauseKeyword == "if") {
600       if (checkAllowed(ifClause) || parser.parseLParen() ||
601           parser.parseOperand(ifCond.first) ||
602           parser.parseColonType(ifCond.second) || parser.parseRParen())
603         return failure();
604       clauseSegments[pos[ifClause]] = 1;
605     } else if (clauseKeyword == "num_threads") {
606       if (checkAllowed(numThreadsClause) || parser.parseLParen() ||
607           parser.parseOperand(numThreads.first) ||
608           parser.parseColonType(numThreads.second) || parser.parseRParen())
609         return failure();
610       clauseSegments[pos[numThreadsClause]] = 1;
611     } else if (clauseKeyword == "device") {
612       if (checkAllowed(deviceClause) || parser.parseLParen() ||
613           parser.parseOperand(device.first) ||
614           parser.parseColonType(device.second) || parser.parseRParen())
615         return failure();
616       clauseSegments[pos[deviceClause]] = 1;
617     } else if (clauseKeyword == "thread_limit") {
618       if (checkAllowed(threadLimitClause) || parser.parseLParen() ||
619           parser.parseOperand(threadLimit.first) ||
620           parser.parseColonType(threadLimit.second) || parser.parseRParen())
621         return failure();
622       clauseSegments[pos[threadLimitClause]] = 1;
623     } else if (clauseKeyword == "allocate") {
624       if (checkAllowed(allocateClause) || parser.parseLParen() ||
625           parseAllocateAndAllocator(parser, allocates, allocateTypes,
626                                     allocators, allocatorTypes) ||
627           parser.parseRParen())
628         return failure();
629       clauseSegments[pos[allocateClause]] = allocates.size();
630       clauseSegments[pos[allocateClause] + 1] = allocators.size();
631     } else if (clauseKeyword == "proc_bind") {
632       if (checkAllowed(procBindClause) ||
633           parseClauseAttr<ClauseProcBindKindAttr>(parser, result,
634                                                   "proc_bind_val", "proc bind"))
635         return failure();
636     } else if (clauseKeyword == "reduction") {
637       if (checkAllowed(reductionClause) || parser.parseLParen() ||
638           parseReductionVarList(parser, reductionVars, reductionVarTypes,
639                                 reductions) ||
640           parser.parseRParen())
641         return failure();
642       clauseSegments[pos[reductionClause]] = reductionVars.size();
643     } else if (clauseKeyword == "nowait") {
644       if (checkAllowed(nowaitClause))
645         return failure();
646       auto attr = UnitAttr::get(parser.getBuilder().getContext());
647       result.addAttribute("nowait", attr);
648     } else if (clauseKeyword == "linear") {
649       if (checkAllowed(linearClause) ||
650           parseLinearClause(parser, linears, linearTypes, linearSteps))
651         return failure();
652       clauseSegments[pos[linearClause]] = linears.size();
653       clauseSegments[pos[linearClause] + 1] = linearSteps.size();
654     } else if (clauseKeyword == "schedule") {
655       if (checkAllowed(scheduleClause) ||
656           parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize,
657                               scheduleChunkType))
658         return failure();
659       if (scheduleChunkSize) {
660         clauseSegments[pos[scheduleClause]] = 1;
661       }
662     } else if (clauseKeyword == "collapse") {
663       auto type = parser.getBuilder().getI64Type();
664       mlir::IntegerAttr attr;
665       if (checkAllowed(collapseClause) || parser.parseLParen() ||
666           parser.parseAttribute(attr, type) || parser.parseRParen())
667         return failure();
668       result.addAttribute("collapse_val", attr);
669     } else if (clauseKeyword == "ordered") {
670       mlir::IntegerAttr attr;
671       if (checkAllowed(orderedClause))
672         return failure();
673       if (succeeded(parser.parseOptionalLParen())) {
674         auto type = parser.getBuilder().getI64Type();
675         if (parser.parseAttribute(attr, type) || parser.parseRParen())
676           return failure();
677       } else {
678         // Use 0 to represent no ordered parameter was specified
679         attr = parser.getBuilder().getI64IntegerAttr(0);
680       }
681       result.addAttribute("ordered_val", attr);
682     } else if (clauseKeyword == "order") {
683       if (checkAllowed(orderClause) ||
684           parseClauseAttr<ClauseOrderKindAttr>(parser, result, "order_val",
685                                                "order"))
686         return failure();
687     } else if (clauseKeyword == "memory_order") {
688       if (checkAllowed(memoryOrderClause) ||
689           parseClauseAttr<ClauseMemoryOrderKindAttr>(
690               parser, result, "memory_order", "memory order"))
691         return failure();
692     } else if (clauseKeyword == "hint") {
693       IntegerAttr hint;
694       if (checkAllowed(hintClause) ||
695           parseSynchronizationHint(parser, hint, false))
696         return failure();
697       result.addAttribute("hint", hint);
698     } else {
699       return parser.emitError(parser.getNameLoc())
700              << clauseKeyword << " is not a valid clause";
701     }
702   }
703 
704   // Add if parameter.
705   if (done[ifClause] && clauseSegments[pos[ifClause]] &&
706       failed(
707           parser.resolveOperand(ifCond.first, ifCond.second, result.operands)))
708     return failure();
709 
710   // Add num_threads parameter.
711   if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] &&
712       failed(parser.resolveOperand(numThreads.first, numThreads.second,
713                                    result.operands)))
714     return failure();
715 
716   // Add device parameter.
717   if (done[deviceClause] && clauseSegments[pos[deviceClause]] &&
718       failed(
719           parser.resolveOperand(device.first, device.second, result.operands)))
720     return failure();
721 
722   // Add thread_limit parameter.
723   if (done[threadLimitClause] && clauseSegments[pos[threadLimitClause]] &&
724       failed(parser.resolveOperand(threadLimit.first, threadLimit.second,
725                                    result.operands)))
726     return failure();
727 
728   // Add allocate parameters.
729   if (done[allocateClause] && clauseSegments[pos[allocateClause]] &&
730       failed(parser.resolveOperands(allocates, allocateTypes,
731                                     allocates[0].location, result.operands)))
732     return failure();
733 
734   // Add allocator parameters.
735   if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] &&
736       failed(parser.resolveOperands(allocators, allocatorTypes,
737                                     allocators[0].location, result.operands)))
738     return failure();
739 
740   // Add reduction parameters and symbols
741   if (done[reductionClause] && clauseSegments[pos[reductionClause]]) {
742     if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
743                                       parser.getNameLoc(), result.operands)))
744       return failure();
745     result.addAttribute("reductions", reductions);
746   }
747 
748   // Add linear parameters
749   if (done[linearClause] && clauseSegments[pos[linearClause]]) {
750     auto linearStepType = parser.getBuilder().getI32Type();
751     SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
752     if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location,
753                                       result.operands)) ||
754         failed(parser.resolveOperands(linearSteps, linearStepTypes,
755                                       linearSteps[0].location,
756                                       result.operands)))
757       return failure();
758   }
759 
760   // Add schedule parameters
761   if (done[scheduleClause] && !schedule.empty()) {
762     schedule[0] = llvm::toUpper(schedule[0]);
763     if (Optional<ClauseScheduleKind> sched =
764             symbolizeClauseScheduleKind(schedule)) {
765       auto attr = ClauseScheduleKindAttr::get(parser.getContext(), *sched);
766       result.addAttribute("schedule_val", attr);
767     } else {
768       return parser.emitError(parser.getCurrentLocation(),
769                               "invalid schedule kind");
770     }
771     if (!modifiers.empty()) {
772       SMLoc loc = parser.getCurrentLocation();
773       if (Optional<ScheduleModifier> mod =
774               symbolizeScheduleModifier(modifiers[0])) {
775         result.addAttribute(
776             "schedule_modifier",
777             ScheduleModifierAttr::get(parser.getContext(), *mod));
778       } else {
779         return parser.emitError(loc, "invalid schedule modifier");
780       }
781       // Only SIMD attribute is allowed here!
782       if (modifiers.size() > 1) {
783         assert(symbolizeScheduleModifier(modifiers[1]) ==
784                ScheduleModifier::simd);
785         auto attr = UnitAttr::get(parser.getBuilder().getContext());
786         result.addAttribute("simd_modifier", attr);
787       }
788     }
789     if (scheduleChunkSize)
790       parser.resolveOperand(*scheduleChunkSize, scheduleChunkType,
791                             result.operands);
792   }
793 
794   segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end());
795 
796   return success();
797 }
798 
799 //===----------------------------------------------------------------------===//
800 // Verifier for SectionsOp
801 //===----------------------------------------------------------------------===//
802 
803 LogicalResult SectionsOp::verify() {
804   if (allocate_vars().size() != allocators_vars().size())
805     return emitError(
806         "expected equal sizes for allocate and allocator variables");
807 
808   for (auto &inst : *region().begin()) {
809     if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
810       return emitOpError()
811              << "expected omp.section op or terminator op inside region";
812     }
813   }
814 
815   return verifyReductionVarList(*this, reductions(), reduction_vars());
816 }
817 
818 /// Parses an OpenMP Workshare Loop operation
819 ///
820 /// wsloop ::= `omp.wsloop` loop-control clause-list
821 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
822 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
823 /// steps := `step` `(`ssa-id-list`)`
824 /// clause-list ::= clause clause-list | empty
825 /// clause ::= linear | schedule | collapse | nowait | ordered | order
826 ///          | reduction
827 ParseResult WsLoopOp::parse(OpAsmParser &parser, OperationState &result) {
828   // Parse an opening `(` followed by induction variables followed by `)`
829   SmallVector<OpAsmParser::OperandType> ivs;
830   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
831                                      OpAsmParser::Delimiter::Paren))
832     return failure();
833 
834   int numIVs = static_cast<int>(ivs.size());
835   Type loopVarType;
836   if (parser.parseColonType(loopVarType))
837     return failure();
838 
839   // Parse loop bounds.
840   SmallVector<OpAsmParser::OperandType> lower;
841   if (parser.parseEqual() ||
842       parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
843       parser.resolveOperands(lower, loopVarType, result.operands))
844     return failure();
845 
846   SmallVector<OpAsmParser::OperandType> upper;
847   if (parser.parseKeyword("to") ||
848       parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
849       parser.resolveOperands(upper, loopVarType, result.operands))
850     return failure();
851 
852   if (succeeded(parser.parseOptionalKeyword("inclusive"))) {
853     auto attr = UnitAttr::get(parser.getBuilder().getContext());
854     result.addAttribute("inclusive", attr);
855   }
856 
857   // Parse step values.
858   SmallVector<OpAsmParser::OperandType> steps;
859   if (parser.parseKeyword("step") ||
860       parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
861       parser.resolveOperands(steps, loopVarType, result.operands))
862     return failure();
863 
864   SmallVector<ClauseType> clauses = {
865       linearClause,  reductionClause, collapseClause, orderClause,
866       orderedClause, nowaitClause,    scheduleClause};
867   SmallVector<int> segments{numIVs, numIVs, numIVs};
868   if (failed(parseClauses(parser, result, clauses, segments)))
869     return failure();
870 
871   result.addAttribute("operand_segment_sizes",
872                       parser.getBuilder().getI32VectorAttr(segments));
873 
874   // Now parse the body.
875   Region *body = result.addRegion();
876   SmallVector<Type> ivTypes(numIVs, loopVarType);
877   SmallVector<OpAsmParser::OperandType> blockArgs(ivs);
878   if (parser.parseRegion(*body, blockArgs, ivTypes))
879     return failure();
880   return success();
881 }
882 
883 void WsLoopOp::print(OpAsmPrinter &p) {
884   auto args = getRegion().front().getArguments();
885   p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound()
886     << ") to (" << upperBound() << ") ";
887   if (inclusive()) {
888     p << "inclusive ";
889   }
890   p << "step (" << step() << ") ";
891 
892   if (!linear_vars().empty())
893     printLinearClause(p, linear_vars(), linear_step_vars());
894 
895   if (auto sched = schedule_val())
896     printScheduleClause(p, sched.getValue(), schedule_modifier(),
897                         simd_modifier(), schedule_chunk_var());
898 
899   if (auto collapse = collapse_val())
900     p << "collapse(" << collapse << ") ";
901 
902   if (nowait())
903     p << "nowait ";
904 
905   if (auto ordered = ordered_val())
906     p << "ordered(" << ordered << ") ";
907 
908   if (auto order = order_val())
909     p << "order(" << stringifyClauseOrderKind(*order) << ") ";
910 
911   if (!reduction_vars().empty()) {
912     printReductionVarList(p << "reduction(", *this, reduction_vars(),
913                           reduction_vars().getTypes(), reductions());
914     p << ")";
915   }
916 
917   p << ' ';
918   p.printRegion(region(), /*printEntryBlockArgs=*/false);
919 }
920 
921 //===----------------------------------------------------------------------===//
922 // ReductionOp
923 //===----------------------------------------------------------------------===//
924 
925 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
926                                               Region &region) {
927   if (parser.parseOptionalKeyword("atomic"))
928     return success();
929   return parser.parseRegion(region);
930 }
931 
932 static void printAtomicReductionRegion(OpAsmPrinter &printer,
933                                        ReductionDeclareOp op, Region &region) {
934   if (region.empty())
935     return;
936   printer << "atomic ";
937   printer.printRegion(region);
938 }
939 
940 LogicalResult ReductionDeclareOp::verify() {
941   if (initializerRegion().empty())
942     return emitOpError() << "expects non-empty initializer region";
943   Block &initializerEntryBlock = initializerRegion().front();
944   if (initializerEntryBlock.getNumArguments() != 1 ||
945       initializerEntryBlock.getArgument(0).getType() != type()) {
946     return emitOpError() << "expects initializer region with one argument "
947                             "of the reduction type";
948   }
949 
950   for (YieldOp yieldOp : initializerRegion().getOps<YieldOp>()) {
951     if (yieldOp.results().size() != 1 ||
952         yieldOp.results().getTypes()[0] != type())
953       return emitOpError() << "expects initializer region to yield a value "
954                               "of the reduction type";
955   }
956 
957   if (reductionRegion().empty())
958     return emitOpError() << "expects non-empty reduction region";
959   Block &reductionEntryBlock = reductionRegion().front();
960   if (reductionEntryBlock.getNumArguments() != 2 ||
961       reductionEntryBlock.getArgumentTypes()[0] !=
962           reductionEntryBlock.getArgumentTypes()[1] ||
963       reductionEntryBlock.getArgumentTypes()[0] != type())
964     return emitOpError() << "expects reduction region with two arguments of "
965                             "the reduction type";
966   for (YieldOp yieldOp : reductionRegion().getOps<YieldOp>()) {
967     if (yieldOp.results().size() != 1 ||
968         yieldOp.results().getTypes()[0] != type())
969       return emitOpError() << "expects reduction region to yield a value "
970                               "of the reduction type";
971   }
972 
973   if (atomicReductionRegion().empty())
974     return success();
975 
976   Block &atomicReductionEntryBlock = atomicReductionRegion().front();
977   if (atomicReductionEntryBlock.getNumArguments() != 2 ||
978       atomicReductionEntryBlock.getArgumentTypes()[0] !=
979           atomicReductionEntryBlock.getArgumentTypes()[1])
980     return emitOpError() << "expects atomic reduction region with two "
981                             "arguments of the same type";
982   auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
983                      .dyn_cast<PointerLikeType>();
984   if (!ptrType || ptrType.getElementType() != type())
985     return emitOpError() << "expects atomic reduction region arguments to "
986                             "be accumulators containing the reduction type";
987   return success();
988 }
989 
990 LogicalResult ReductionOp::verify() {
991   // TODO: generalize this to an op interface when there is more than one op
992   // that supports reductions.
993   auto container = (*this)->getParentOfType<WsLoopOp>();
994   for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
995     if (container.reduction_vars()[i] == accumulator())
996       return success();
997 
998   return emitOpError() << "the accumulator is not used by the parent";
999 }
1000 
1001 //===----------------------------------------------------------------------===//
1002 // WsLoopOp
1003 //===----------------------------------------------------------------------===//
1004 
1005 void WsLoopOp::build(OpBuilder &builder, OperationState &state,
1006                      ValueRange lowerBound, ValueRange upperBound,
1007                      ValueRange step, ArrayRef<NamedAttribute> attributes) {
1008   build(builder, state, TypeRange(), lowerBound, upperBound, step,
1009         /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
1010         /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr,
1011         /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr,
1012         /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr,
1013         /*inclusive=*/nullptr, /*buildBody=*/false);
1014   state.addAttributes(attributes);
1015 }
1016 
1017 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes,
1018                      ValueRange operands, ArrayRef<NamedAttribute> attributes) {
1019   state.addOperands(operands);
1020   state.addAttributes(attributes);
1021   (void)state.addRegion();
1022   assert(resultTypes.empty() && "mismatched number of return types");
1023   state.addTypes(resultTypes);
1024 }
1025 
1026 void WsLoopOp::build(OpBuilder &builder, OperationState &result,
1027                      TypeRange typeRange, ValueRange lowerBounds,
1028                      ValueRange upperBounds, ValueRange steps,
1029                      ValueRange linearVars, ValueRange linearStepVars,
1030                      ValueRange reductionVars, StringAttr scheduleVal,
1031                      Value scheduleChunkVar, IntegerAttr collapseVal,
1032                      UnitAttr nowait, IntegerAttr orderedVal,
1033                      StringAttr orderVal, UnitAttr inclusive, bool buildBody) {
1034   result.addOperands(lowerBounds);
1035   result.addOperands(upperBounds);
1036   result.addOperands(steps);
1037   result.addOperands(linearVars);
1038   result.addOperands(linearStepVars);
1039   if (scheduleChunkVar)
1040     result.addOperands(scheduleChunkVar);
1041 
1042   if (scheduleVal)
1043     result.addAttribute("schedule_val", scheduleVal);
1044   if (collapseVal)
1045     result.addAttribute("collapse_val", collapseVal);
1046   if (nowait)
1047     result.addAttribute("nowait", nowait);
1048   if (orderedVal)
1049     result.addAttribute("ordered_val", orderedVal);
1050   if (orderVal)
1051     result.addAttribute("order", orderVal);
1052   if (inclusive)
1053     result.addAttribute("inclusive", inclusive);
1054   result.addAttribute(
1055       WsLoopOp::getOperandSegmentSizeAttr(),
1056       builder.getI32VectorAttr(
1057           {static_cast<int32_t>(lowerBounds.size()),
1058            static_cast<int32_t>(upperBounds.size()),
1059            static_cast<int32_t>(steps.size()),
1060            static_cast<int32_t>(linearVars.size()),
1061            static_cast<int32_t>(linearStepVars.size()),
1062            static_cast<int32_t>(reductionVars.size()),
1063            static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)}));
1064 
1065   Region *bodyRegion = result.addRegion();
1066   if (buildBody) {
1067     OpBuilder::InsertionGuard guard(builder);
1068     unsigned numIVs = steps.size();
1069     SmallVector<Type, 8> argTypes(numIVs, steps.getType().front());
1070     SmallVector<Location, 8> argLocs(numIVs, result.location);
1071     builder.createBlock(bodyRegion, {}, argTypes, argLocs);
1072   }
1073 }
1074 
1075 LogicalResult WsLoopOp::verify() {
1076   return verifyReductionVarList(*this, reductions(), reduction_vars());
1077 }
1078 
1079 //===----------------------------------------------------------------------===//
1080 // Verifier for critical construct (2.17.1)
1081 //===----------------------------------------------------------------------===//
1082 
1083 LogicalResult CriticalDeclareOp::verify() {
1084   return verifySynchronizationHint(*this, hint());
1085 }
1086 
1087 LogicalResult CriticalOp::verify() {
1088   if (nameAttr()) {
1089     SymbolRefAttr symbolRef = nameAttr();
1090     auto decl = SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(
1091         *this, symbolRef);
1092     if (!decl) {
1093       return emitOpError() << "expected symbol reference " << symbolRef
1094                            << " to point to a critical declaration";
1095     }
1096   }
1097 
1098   return success();
1099 }
1100 
1101 //===----------------------------------------------------------------------===//
1102 // Verifier for ordered construct
1103 //===----------------------------------------------------------------------===//
1104 
1105 LogicalResult OrderedOp::verify() {
1106   auto container = (*this)->getParentOfType<WsLoopOp>();
1107   if (!container || !container.ordered_valAttr() ||
1108       container.ordered_valAttr().getInt() == 0)
1109     return emitOpError() << "ordered depend directive must be closely "
1110                          << "nested inside a worksharing-loop with ordered "
1111                          << "clause with parameter present";
1112 
1113   if (container.ordered_valAttr().getInt() !=
1114       (int64_t)num_loops_val().getValue())
1115     return emitOpError() << "number of variables in depend clause does not "
1116                          << "match number of iteration variables in the "
1117                          << "doacross loop";
1118 
1119   return success();
1120 }
1121 
1122 LogicalResult OrderedRegionOp::verify() {
1123   // TODO: The code generation for ordered simd directive is not supported yet.
1124   if (simd())
1125     return failure();
1126 
1127   if (auto container = (*this)->getParentOfType<WsLoopOp>()) {
1128     if (!container.ordered_valAttr() ||
1129         container.ordered_valAttr().getInt() != 0)
1130       return emitOpError() << "ordered region must be closely nested inside "
1131                            << "a worksharing-loop region with an ordered "
1132                            << "clause without parameter present";
1133   }
1134 
1135   return success();
1136 }
1137 
1138 //===----------------------------------------------------------------------===//
1139 // AtomicReadOp
1140 //===----------------------------------------------------------------------===//
1141 
1142 /// Parser for AtomicReadOp
1143 ///
1144 /// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type
1145 /// address ::= operand `:` type
1146 ParseResult AtomicReadOp::parse(OpAsmParser &parser, OperationState &result) {
1147   OpAsmParser::OperandType x, v;
1148   Type addressType;
1149   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1150   SmallVector<int> segments;
1151 
1152   if (parser.parseOperand(v) || parser.parseEqual() || parser.parseOperand(x) ||
1153       parseClauses(parser, result, clauses, segments) ||
1154       parser.parseColonType(addressType) ||
1155       parser.resolveOperand(x, addressType, result.operands) ||
1156       parser.resolveOperand(v, addressType, result.operands))
1157     return failure();
1158 
1159   return success();
1160 }
1161 
1162 void AtomicReadOp::print(OpAsmPrinter &p) {
1163   p << " " << v() << " = " << x() << " ";
1164   if (auto mo = memory_order())
1165     p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
1166   if (hintAttr())
1167     printSynchronizationHint(p << " ", *this, hintAttr());
1168   p << ": " << x().getType();
1169 }
1170 
1171 /// Verifier for AtomicReadOp
1172 LogicalResult AtomicReadOp::verify() {
1173   if (auto mo = memory_order()) {
1174     if (*mo == ClauseMemoryOrderKind::acq_rel ||
1175         *mo == ClauseMemoryOrderKind::release) {
1176       return emitError(
1177           "memory-order must not be acq_rel or release for atomic reads");
1178     }
1179   }
1180   if (x() == v())
1181     return emitError(
1182         "read and write must not be to the same location for atomic reads");
1183   return verifySynchronizationHint(*this, hint());
1184 }
1185 
1186 //===----------------------------------------------------------------------===//
1187 // AtomicWriteOp
1188 //===----------------------------------------------------------------------===//
1189 
1190 /// Parser for AtomicWriteOp
1191 ///
1192 /// operation ::= `omp.atomic.write` atomic-clause-list operands
1193 /// operands ::= address `,` value
1194 /// address ::= operand `:` type
1195 /// value ::= operand `:` type
1196 ParseResult AtomicWriteOp::parse(OpAsmParser &parser, OperationState &result) {
1197   OpAsmParser::OperandType address, value;
1198   Type addrType, valueType;
1199   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1200   SmallVector<int> segments;
1201 
1202   if (parser.parseOperand(address) || parser.parseEqual() ||
1203       parser.parseOperand(value) ||
1204       parseClauses(parser, result, clauses, segments) ||
1205       parser.parseColonType(addrType) || parser.parseComma() ||
1206       parser.parseType(valueType) ||
1207       parser.resolveOperand(address, addrType, result.operands) ||
1208       parser.resolveOperand(value, valueType, result.operands))
1209     return failure();
1210   return success();
1211 }
1212 
1213 void AtomicWriteOp::print(OpAsmPrinter &p) {
1214   p << " " << address() << " = " << value() << " ";
1215   if (auto mo = memory_order())
1216     p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
1217   if (hintAttr())
1218     printSynchronizationHint(p, *this, hintAttr());
1219   p << ": " << address().getType() << ", " << value().getType();
1220 }
1221 
1222 /// Verifier for AtomicWriteOp
1223 LogicalResult AtomicWriteOp::verify() {
1224   if (auto mo = memory_order()) {
1225     if (*mo == ClauseMemoryOrderKind::acq_rel ||
1226         *mo == ClauseMemoryOrderKind::acquire) {
1227       return emitError(
1228           "memory-order must not be acq_rel or acquire for atomic writes");
1229     }
1230   }
1231   return verifySynchronizationHint(*this, hint());
1232 }
1233 
1234 //===----------------------------------------------------------------------===//
1235 // AtomicUpdateOp
1236 //===----------------------------------------------------------------------===//
1237 
1238 /// Parser for AtomicUpdateOp
1239 ///
1240 /// operation ::= `omp.atomic.update` atomic-clause-list ssa-id-and-type region
1241 ParseResult AtomicUpdateOp::parse(OpAsmParser &parser, OperationState &result) {
1242   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1243   SmallVector<int> segments;
1244   OpAsmParser::OperandType x, expr;
1245   Type xType;
1246 
1247   if (parseClauses(parser, result, clauses, segments) ||
1248       parser.parseOperand(x) || parser.parseColon() ||
1249       parser.parseType(xType) ||
1250       parser.resolveOperand(x, xType, result.operands) ||
1251       parser.parseRegion(*result.addRegion()))
1252     return failure();
1253   return success();
1254 }
1255 
1256 void AtomicUpdateOp::print(OpAsmPrinter &p) {
1257   p << " ";
1258   if (auto mo = memory_order())
1259     p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
1260   if (hintAttr())
1261     printSynchronizationHint(p, *this, hintAttr());
1262   p << x() << " : " << x().getType();
1263   p.printRegion(region());
1264 }
1265 
1266 /// Verifier for AtomicUpdateOp
1267 LogicalResult AtomicUpdateOp::verify() {
1268   if (auto mo = memory_order()) {
1269     if (*mo == ClauseMemoryOrderKind::acq_rel ||
1270         *mo == ClauseMemoryOrderKind::acquire) {
1271       return emitError(
1272           "memory-order must not be acq_rel or acquire for atomic updates");
1273     }
1274   }
1275 
1276   if (region().getNumArguments() != 1)
1277     return emitError("the region must accept exactly one argument");
1278 
1279   if (x().getType().cast<PointerLikeType>().getElementType() !=
1280       region().getArgument(0).getType()) {
1281     return emitError("the type of the operand must be a pointer type whose "
1282                      "element type is the same as that of the region argument");
1283   }
1284 
1285   YieldOp yieldOp = *region().getOps<YieldOp>().begin();
1286 
1287   if (yieldOp.results().size() != 1)
1288     return emitError("only updated value must be returned");
1289   if (yieldOp.results().front().getType() != region().getArgument(0).getType())
1290     return emitError("input and yielded value must have the same type");
1291   return success();
1292 }
1293 
1294 //===----------------------------------------------------------------------===//
1295 // AtomicCaptureOp
1296 //===----------------------------------------------------------------------===//
1297 
1298 ParseResult AtomicCaptureOp::parse(OpAsmParser &parser,
1299                                    OperationState &result) {
1300   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1301   SmallVector<int> segments;
1302   if (parseClauses(parser, result, clauses, segments) ||
1303       parser.parseRegion(*result.addRegion()))
1304     return failure();
1305   return success();
1306 }
1307 
1308 void AtomicCaptureOp::print(OpAsmPrinter &p) {
1309   if (memory_order())
1310     p << "memory_order(" << memory_order() << ") ";
1311   if (hintAttr())
1312     printSynchronizationHint(p, *this, hintAttr());
1313   p.printRegion(region());
1314 }
1315 
1316 /// Verifier for AtomicCaptureOp
1317 LogicalResult AtomicCaptureOp::verify() {
1318   Block::OpListType &ops = region().front().getOperations();
1319   if (ops.size() != 3)
1320     return emitError()
1321            << "expected three operations in omp.atomic.capture region (one "
1322               "terminator, and two atomic ops)";
1323   auto &firstOp = ops.front();
1324   auto &secondOp = *ops.getNextNode(firstOp);
1325   auto firstReadStmt = dyn_cast<AtomicReadOp>(firstOp);
1326   auto firstUpdateStmt = dyn_cast<AtomicUpdateOp>(firstOp);
1327   auto secondReadStmt = dyn_cast<AtomicReadOp>(secondOp);
1328   auto secondUpdateStmt = dyn_cast<AtomicUpdateOp>(secondOp);
1329   auto secondWriteStmt = dyn_cast<AtomicWriteOp>(secondOp);
1330 
1331   if (!((firstUpdateStmt && secondReadStmt) ||
1332         (firstReadStmt && secondUpdateStmt) ||
1333         (firstReadStmt && secondWriteStmt)))
1334     return ops.front().emitError()
1335            << "invalid sequence of operations in the capture region";
1336   if (firstUpdateStmt && secondReadStmt &&
1337       firstUpdateStmt.x() != secondReadStmt.x())
1338     return firstUpdateStmt.emitError()
1339            << "updated variable in omp.atomic.update must be captured in "
1340               "second operation";
1341   if (firstReadStmt && secondUpdateStmt &&
1342       firstReadStmt.x() != secondUpdateStmt.x())
1343     return firstReadStmt.emitError()
1344            << "captured variable in omp.atomic.read must be updated in second "
1345               "operation";
1346   if (firstReadStmt && secondWriteStmt &&
1347       firstReadStmt.x() != secondWriteStmt.address())
1348     return firstReadStmt.emitError()
1349            << "captured variable in omp.atomic.read must be updated in "
1350               "second operation";
1351   return success();
1352 }
1353 
1354 #define GET_ATTRDEF_CLASSES
1355 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
1356 
1357 #define GET_OP_CLASSES
1358 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
1359