1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/OperationSupport.h"
19 
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/SmallString.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/ADT/StringSwitch.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include <cstddef>
27 
28 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
29 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
30 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
31 
32 using namespace mlir;
33 using namespace mlir::omp;
34 
35 namespace {
36 /// Model for pointer-like types that already provide a `getElementType` method.
37 template <typename T>
38 struct PointerLikeModel
39     : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
40   Type getElementType(Type pointer) const {
41     return pointer.cast<T>().getElementType();
42   }
43 };
44 } // namespace
45 
46 void OpenMPDialect::initialize() {
47   addOperations<
48 #define GET_OP_LIST
49 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
50       >();
51   addAttributes<
52 #define GET_ATTRDEF_LIST
53 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
54       >();
55 
56   LLVM::LLVMPointerType::attachInterface<
57       PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
58   MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // ParallelOp
63 //===----------------------------------------------------------------------===//
64 
/// Convenience builder: builds a parallel op through the generated builder
/// with a null if_expr_var, a null num_threads_var, empty allocate/allocator
/// lists and a null proc_bind_val, then adds `attributes` to the state.
void ParallelOp::build(OpBuilder &builder, OperationState &state,
                       ArrayRef<NamedAttribute> attributes) {
  ParallelOp::build(
      builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
      /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
      /*proc_bind_val=*/nullptr);
  // Caller-supplied attributes are added after the generated builder has
  // populated the state.
  state.addAttributes(attributes);
}
73 
74 //===----------------------------------------------------------------------===//
75 // Parser and printer for Allocate Clause
76 //===----------------------------------------------------------------------===//
77 
78 /// Parse an allocate clause with allocators and a list of operands with types.
79 ///
80 /// allocate-operand-list :: = allocate-operand |
81 ///                            allocator-operand `,` allocate-operand-list
82 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
83 /// ssa-id-and-type ::= ssa-id `:` type
84 static ParseResult parseAllocateAndAllocator(
85     OpAsmParser &parser,
86     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
87     SmallVectorImpl<Type> &typesAllocate,
88     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
89     SmallVectorImpl<Type> &typesAllocator) {
90 
91   return parser.parseCommaSeparatedList([&]() -> ParseResult {
92     OpAsmParser::OperandType operand;
93     Type type;
94     if (parser.parseOperand(operand) || parser.parseColonType(type))
95       return failure();
96     operandsAllocator.push_back(operand);
97     typesAllocator.push_back(type);
98     if (parser.parseArrow())
99       return failure();
100     if (parser.parseOperand(operand) || parser.parseColonType(type))
101       return failure();
102 
103     operandsAllocate.push_back(operand);
104     typesAllocate.push_back(type);
105     return success();
106   });
107 }
108 
109 /// Print allocate clause
110 static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op,
111                                       OperandRange varsAllocate,
112                                       TypeRange typesAllocate,
113                                       OperandRange varsAllocator,
114                                       TypeRange typesAllocator) {
115   for (unsigned i = 0; i < varsAllocate.size(); ++i) {
116     std::string separator = i == varsAllocate.size() - 1 ? "" : ", ";
117     p << varsAllocator[i] << " : " << typesAllocator[i] << " -> ";
118     p << varsAllocate[i] << " : " << typesAllocate[i] << separator;
119   }
120 }
121 
122 /// Parse a clause attribute (StringEnumAttr)
123 template <typename ClauseAttr>
124 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
125   using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
126   StringRef enumStr;
127   SMLoc loc = parser.getCurrentLocation();
128   if (parser.parseKeyword(&enumStr))
129     return failure();
130   if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
131     attr = ClauseAttr::get(parser.getContext(), *enumValue);
132     return success();
133   }
134   return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
135 }
136 
137 template <typename ClauseAttr>
138 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
139   p << stringifyEnum(attr.getValue());
140 }
141 
142 //===----------------------------------------------------------------------===//
143 // Parser and printer for Procbind Clause
144 //===----------------------------------------------------------------------===//
145 
146 ParseResult parseProcBindKind(OpAsmParser &parser,
147                               omp::ClauseProcBindKindAttr &procBindAttr) {
148   StringRef procBindStr;
149   if (parser.parseKeyword(&procBindStr))
150     return failure();
151   if (auto procBindVal = symbolizeClauseProcBindKind(procBindStr)) {
152     procBindAttr =
153         ClauseProcBindKindAttr::get(parser.getContext(), *procBindVal);
154     return success();
155   }
156   return failure();
157 }
158 
159 void printProcBindKind(OpAsmPrinter &p, Operation *op,
160                        omp::ClauseProcBindKindAttr procBindAttr) {
161   p << stringifyClauseProcBindKind(procBindAttr.getValue());
162 }
163 
164 LogicalResult ParallelOp::verify() {
165   if (allocate_vars().size() != allocators_vars().size())
166     return emitError(
167         "expected equal sizes for allocate and allocator variables");
168   return success();
169 }
170 
171 //===----------------------------------------------------------------------===//
172 // Parser and printer for Linear Clause
173 //===----------------------------------------------------------------------===//
174 
175 /// linear ::= `linear` `(` linear-list `)`
176 /// linear-list := linear-val | linear-val linear-list
177 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
178 static ParseResult
179 parseLinearClause(OpAsmParser &parser,
180                   SmallVectorImpl<OpAsmParser::OperandType> &vars,
181                   SmallVectorImpl<Type> &types,
182                   SmallVectorImpl<OpAsmParser::OperandType> &stepVars) {
183   if (parser.parseLParen())
184     return failure();
185 
186   do {
187     OpAsmParser::OperandType var;
188     Type type;
189     OpAsmParser::OperandType stepVar;
190     if (parser.parseOperand(var) || parser.parseEqual() ||
191         parser.parseOperand(stepVar) || parser.parseColonType(type))
192       return failure();
193 
194     vars.push_back(var);
195     types.push_back(type);
196     stepVars.push_back(stepVar);
197   } while (succeeded(parser.parseOptionalComma()));
198 
199   if (parser.parseRParen())
200     return failure();
201 
202   return success();
203 }
204 
205 /// Print Linear Clause
206 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars,
207                               OperandRange linearStepVars) {
208   size_t linearVarsSize = linearVars.size();
209   p << "linear(";
210   for (unsigned i = 0; i < linearVarsSize; ++i) {
211     std::string separator = i == linearVarsSize - 1 ? ") " : ", ";
212     p << linearVars[i];
213     if (linearStepVars.size() > i)
214       p << " = " << linearStepVars[i];
215     p << " : " << linearVars[i].getType() << separator;
216   }
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // Parser, printer and verifier for Schedule Clause
221 //===----------------------------------------------------------------------===//
222 
223 static ParseResult
224 verifyScheduleModifiers(OpAsmParser &parser,
225                         SmallVectorImpl<SmallString<12>> &modifiers) {
226   if (modifiers.size() > 2)
227     return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
228   for (const auto &mod : modifiers) {
229     // Translate the string. If it has no value, then it was not a valid
230     // modifier!
231     auto symbol = symbolizeScheduleModifier(mod);
232     if (!symbol.hasValue())
233       return parser.emitError(parser.getNameLoc())
234              << " unknown modifier type: " << mod;
235   }
236 
237   // If we have one modifier that is "simd", then stick a "none" modiifer in
238   // index 0.
239   if (modifiers.size() == 1) {
240     if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
241       modifiers.push_back(modifiers[0]);
242       modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
243     }
244   } else if (modifiers.size() == 2) {
245     // If there are two modifier:
246     // First modifier should not be simd, second one should be simd
247     if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
248         symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
249       return parser.emitError(parser.getNameLoc())
250              << " incorrect modifier order";
251   }
252   return success();
253 }
254 
255 /// schedule ::= `schedule` `(` sched-list `)`
256 /// sched-list ::= sched-val | sched-val sched-list |
257 ///                sched-val `,` sched-modifier
258 /// sched-val ::= sched-with-chunk | sched-wo-chunk
259 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
260 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
261 /// sched-wo-chunk ::=  `auto` | `runtime`
262 /// sched-modifier ::=  sched-mod-val | sched-mod-val `,` sched-mod-val
263 /// sched-mod-val ::=  `monotonic` | `nonmonotonic` | `simd` | `none`
264 static ParseResult
265 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
266                     SmallVectorImpl<SmallString<12>> &modifiers,
267                     Optional<OpAsmParser::OperandType> &chunkSize,
268                     Type &chunkType) {
269   if (parser.parseLParen())
270     return failure();
271 
272   StringRef keyword;
273   if (parser.parseKeyword(&keyword))
274     return failure();
275 
276   schedule = keyword;
277   if (keyword == "static" || keyword == "dynamic" || keyword == "guided") {
278     if (succeeded(parser.parseOptionalEqual())) {
279       chunkSize = OpAsmParser::OperandType{};
280       if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
281         return failure();
282     } else {
283       chunkSize = llvm::NoneType::None;
284     }
285   } else if (keyword == "auto" || keyword == "runtime") {
286     chunkSize = llvm::NoneType::None;
287   } else {
288     return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
289   }
290 
291   // If there is a comma, we have one or more modifiers..
292   while (succeeded(parser.parseOptionalComma())) {
293     StringRef mod;
294     if (parser.parseKeyword(&mod))
295       return failure();
296     modifiers.push_back(mod);
297   }
298 
299   if (parser.parseRParen())
300     return failure();
301 
302   if (verifyScheduleModifiers(parser, modifiers))
303     return failure();
304 
305   return success();
306 }
307 
308 /// Print schedule clause
309 static void printScheduleClause(OpAsmPrinter &p, ClauseScheduleKind sched,
310                                 Optional<ScheduleModifier> modifier, bool simd,
311                                 Value scheduleChunkVar) {
312   p << "schedule(" << stringifyClauseScheduleKind(sched).lower();
313   if (scheduleChunkVar)
314     p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType();
315   if (modifier)
316     p << ", " << stringifyScheduleModifier(*modifier);
317   if (simd)
318     p << ", simd";
319   p << ") ";
320 }
321 
322 //===----------------------------------------------------------------------===//
323 // Parser, printer and verifier for ReductionVarList
324 //===----------------------------------------------------------------------===//
325 
326 /// reduction-entry-list ::= reduction-entry
327 ///                        | reduction-entry-list `,` reduction-entry
328 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
329 static ParseResult parseReductionVarList(
330     OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &operands,
331     SmallVectorImpl<Type> &types, ArrayAttr &redcuctionSymbols) {
332   SmallVector<SymbolRefAttr> reductionVec;
333   do {
334     if (parser.parseAttribute(reductionVec.emplace_back()) ||
335         parser.parseArrow() || parser.parseOperand(operands.emplace_back()) ||
336         parser.parseColonType(types.emplace_back()))
337       return failure();
338   } while (succeeded(parser.parseOptionalComma()));
339   SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
340   redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions);
341   return success();
342 }
343 
344 /// Print Reduction clause
345 static void printReductionVarList(OpAsmPrinter &p, Operation *op,
346                                   OperandRange reductionVars,
347                                   TypeRange reductionTypes,
348                                   Optional<ArrayAttr> reductions) {
349   for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
350     if (i != 0)
351       p << ", ";
352     p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
353       << reductionVars[i].getType();
354   }
355 }
356 
357 /// Verifies Reduction Clause
358 static LogicalResult verifyReductionVarList(Operation *op,
359                                             Optional<ArrayAttr> reductions,
360                                             OperandRange reductionVars) {
361   if (!reductionVars.empty()) {
362     if (!reductions || reductions->size() != reductionVars.size())
363       return op->emitOpError()
364              << "expected as many reduction symbol references "
365                 "as reduction variables";
366   } else {
367     if (reductions)
368       return op->emitOpError() << "unexpected reduction symbol references";
369     return success();
370   }
371 
372   DenseSet<Value> accumulators;
373   for (auto args : llvm::zip(reductionVars, *reductions)) {
374     Value accum = std::get<0>(args);
375 
376     if (!accumulators.insert(accum).second)
377       return op->emitOpError() << "accumulator variable used more than once";
378 
379     Type varType = accum.getType().cast<PointerLikeType>();
380     auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
381     auto decl =
382         SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
383     if (!decl)
384       return op->emitOpError() << "expected symbol reference " << symbolRef
385                                << " to point to a reduction declaration";
386 
387     if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
388       return op->emitOpError()
389              << "expected accumulator (" << varType
390              << ") to be the same type as reduction declaration ("
391              << decl.getAccumulatorType() << ")";
392   }
393 
394   return success();
395 }
396 
397 //===----------------------------------------------------------------------===//
398 // Parser, printer and verifier for Synchronization Hint (2.17.12)
399 //===----------------------------------------------------------------------===//
400 
401 /// Parses a Synchronization Hint clause. The value of hint is an integer
402 /// which is a combination of different hints from `omp_sync_hint_t`.
403 ///
404 /// hint-clause = `hint` `(` hint-value `)`
405 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
406                                             IntegerAttr &hintAttr) {
407   StringRef hintKeyword;
408   int64_t hint = 0;
409   do {
410     if (failed(parser.parseKeyword(&hintKeyword)))
411       return failure();
412     if (hintKeyword == "uncontended")
413       hint |= 1;
414     else if (hintKeyword == "contended")
415       hint |= 2;
416     else if (hintKeyword == "nonspeculative")
417       hint |= 4;
418     else if (hintKeyword == "speculative")
419       hint |= 8;
420     else
421       return parser.emitError(parser.getCurrentLocation())
422              << hintKeyword << " is not a valid hint";
423   } while (succeeded(parser.parseOptionalComma()));
424   hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
425   return success();
426 }
427 
428 /// Prints a Synchronization Hint clause
429 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
430                                      IntegerAttr hintAttr) {
431   int64_t hint = hintAttr.getInt();
432 
433   if (hint == 0)
434     return;
435 
436   // Helper function to get n-th bit from the right end of `value`
437   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
438 
439   bool uncontended = bitn(hint, 0);
440   bool contended = bitn(hint, 1);
441   bool nonspeculative = bitn(hint, 2);
442   bool speculative = bitn(hint, 3);
443 
444   SmallVector<StringRef> hints;
445   if (uncontended)
446     hints.push_back("uncontended");
447   if (contended)
448     hints.push_back("contended");
449   if (nonspeculative)
450     hints.push_back("nonspeculative");
451   if (speculative)
452     hints.push_back("speculative");
453 
454   llvm::interleaveComma(hints, p);
455 }
456 
457 /// Verifies a synchronization hint clause
458 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
459 
460   // Helper function to get n-th bit from the right end of `value`
461   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
462 
463   bool uncontended = bitn(hint, 0);
464   bool contended = bitn(hint, 1);
465   bool nonspeculative = bitn(hint, 2);
466   bool speculative = bitn(hint, 3);
467 
468   if (uncontended && contended)
469     return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
470                                 "omp_sync_hint_contended cannot be combined";
471   if (nonspeculative && speculative)
472     return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
473                                 "omp_sync_hint_speculative cannot be combined.";
474   return success();
475 }
476 
/// Clauses recognized by parseClauses. The enumerators index parseClauses'
/// per-clause bookkeeping tables, and `COUNT` is the number of clauses.
enum ClauseType {
  allocateClause = 0,
  reductionClause = 1,
  nowaitClause = 2,
  linearClause = 3,
  scheduleClause = 4,
  collapseClause = 5,
  orderClause = 6,
  orderedClause = 7,
  COUNT = 8
};
488 
489 //===----------------------------------------------------------------------===//
490 // Parser for Clause List
491 //===----------------------------------------------------------------------===//
492 
493 /// Parse a list of clauses. The clauses can appear in any order, but their
494 /// operand segment indices are in the same order that they are passed in the
495 /// `clauses` list. The operand segments are added over the prevSegments
496 
497 /// clause-list ::= clause clause-list | empty
498 /// clause ::= allocate | reduction | nowait | linear | schedule | collapse
499 ///          | order | ordered
500 /// allocate ::= `allocate` `(` allocate-operand-list `)`
501 /// reduction ::= `reduction` `(` reduction-entry-list `)`
502 /// nowait ::= `nowait`
503 /// linear ::= `linear` `(` linear-list `)`
504 /// schedule ::= `schedule` `(` sched-list `)`
505 /// collapse ::= `collapse` `(` ssa-id-and-type `)`
506 /// order ::= `order` `(` `concurrent` `)`
507 /// ordered ::= `ordered` `(` ssa-id-and-type `)`
508 ///
509 /// Note that each clause can only appear once in the clase-list.
510 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
511                                 SmallVectorImpl<ClauseType> &clauses,
512                                 SmallVectorImpl<int> &segments) {
513 
514   // Check done[clause] to see if it has been parsed already
515   BitVector done(ClauseType::COUNT, false);
516 
517   // See pos[clause] to get position of clause in operand segments
518   SmallVector<int> pos(ClauseType::COUNT, -1);
519 
520   // Stores the last parsed clause keyword
521   StringRef clauseKeyword;
522   StringRef opName = result.name.getStringRef();
523 
524   // Containers for storing operands, types and attributes for various clauses
525   SmallVector<OpAsmParser::OperandType> allocates, allocators;
526   SmallVector<Type> allocateTypes, allocatorTypes;
527 
528   ArrayAttr reductions;
529   SmallVector<OpAsmParser::OperandType> reductionVars;
530   SmallVector<Type> reductionVarTypes;
531 
532   SmallVector<OpAsmParser::OperandType> linears;
533   SmallVector<Type> linearTypes;
534   SmallVector<OpAsmParser::OperandType> linearSteps;
535 
536   SmallString<8> schedule;
537   SmallVector<SmallString<12>> modifiers;
538   Optional<OpAsmParser::OperandType> scheduleChunkSize;
539   Type scheduleChunkType;
540 
541   // Compute the position of clauses in operand segments
542   int currPos = 0;
543   for (ClauseType clause : clauses) {
544 
545     // Skip the following clauses - they do not take any position in operand
546     // segments
547     if (clause == nowaitClause || clause == collapseClause ||
548         clause == orderClause || clause == orderedClause)
549       continue;
550 
551     pos[clause] = currPos++;
552 
553     // For the following clauses, two positions are reserved in the operand
554     // segments
555     if (clause == allocateClause || clause == linearClause)
556       currPos++;
557   }
558 
559   SmallVector<int> clauseSegments(currPos);
560 
561   // Helper function to check if a clause is allowed/repeated or not
562   auto checkAllowed = [&](ClauseType clause) -> ParseResult {
563     if (!llvm::is_contained(clauses, clause))
564       return parser.emitError(parser.getCurrentLocation())
565              << clauseKeyword << " is not a valid clause for the " << opName
566              << " operation";
567     if (done[clause])
568       return parser.emitError(parser.getCurrentLocation())
569              << "at most one " << clauseKeyword << " clause can appear on the "
570              << opName << " operation";
571     done[clause] = true;
572     return success();
573   };
574 
575   while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) {
576     if (clauseKeyword == "allocate") {
577       if (checkAllowed(allocateClause) || parser.parseLParen() ||
578           parseAllocateAndAllocator(parser, allocates, allocateTypes,
579                                     allocators, allocatorTypes) ||
580           parser.parseRParen())
581         return failure();
582       clauseSegments[pos[allocateClause]] = allocates.size();
583       clauseSegments[pos[allocateClause] + 1] = allocators.size();
584     } else if (clauseKeyword == "reduction") {
585       if (checkAllowed(reductionClause) || parser.parseLParen() ||
586           parseReductionVarList(parser, reductionVars, reductionVarTypes,
587                                 reductions) ||
588           parser.parseRParen())
589         return failure();
590       clauseSegments[pos[reductionClause]] = reductionVars.size();
591     } else if (clauseKeyword == "nowait") {
592       if (checkAllowed(nowaitClause))
593         return failure();
594       auto attr = UnitAttr::get(parser.getBuilder().getContext());
595       result.addAttribute("nowait", attr);
596     } else if (clauseKeyword == "linear") {
597       if (checkAllowed(linearClause) ||
598           parseLinearClause(parser, linears, linearTypes, linearSteps))
599         return failure();
600       clauseSegments[pos[linearClause]] = linears.size();
601       clauseSegments[pos[linearClause] + 1] = linearSteps.size();
602     } else if (clauseKeyword == "schedule") {
603       if (checkAllowed(scheduleClause) ||
604           parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize,
605                               scheduleChunkType))
606         return failure();
607       if (scheduleChunkSize) {
608         clauseSegments[pos[scheduleClause]] = 1;
609       }
610     } else if (clauseKeyword == "collapse") {
611       auto type = parser.getBuilder().getI64Type();
612       mlir::IntegerAttr attr;
613       if (checkAllowed(collapseClause) || parser.parseLParen() ||
614           parser.parseAttribute(attr, type) || parser.parseRParen())
615         return failure();
616       result.addAttribute("collapse_val", attr);
617     } else if (clauseKeyword == "ordered") {
618       mlir::IntegerAttr attr;
619       if (checkAllowed(orderedClause))
620         return failure();
621       if (succeeded(parser.parseOptionalLParen())) {
622         auto type = parser.getBuilder().getI64Type();
623         if (parser.parseAttribute(attr, type) || parser.parseRParen())
624           return failure();
625       } else {
626         // Use 0 to represent no ordered parameter was specified
627         attr = parser.getBuilder().getI64IntegerAttr(0);
628       }
629       result.addAttribute("ordered_val", attr);
630     } else if (clauseKeyword == "order") {
631       ClauseOrderKindAttr order;
632       if (checkAllowed(orderClause) || parser.parseLParen() ||
633           parseClauseAttr<ClauseOrderKindAttr>(parser, order) ||
634           parser.parseRParen())
635         return failure();
636       result.addAttribute("order_val", order);
637     } else {
638       return parser.emitError(parser.getNameLoc())
639              << clauseKeyword << " is not a valid clause";
640     }
641   }
642 
643   // Add allocate parameters.
644   if (done[allocateClause] && clauseSegments[pos[allocateClause]] &&
645       failed(parser.resolveOperands(allocates, allocateTypes,
646                                     allocates[0].location, result.operands)))
647     return failure();
648 
649   // Add allocator parameters.
650   if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] &&
651       failed(parser.resolveOperands(allocators, allocatorTypes,
652                                     allocators[0].location, result.operands)))
653     return failure();
654 
655   // Add reduction parameters and symbols
656   if (done[reductionClause] && clauseSegments[pos[reductionClause]]) {
657     if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
658                                       parser.getNameLoc(), result.operands)))
659       return failure();
660     result.addAttribute("reductions", reductions);
661   }
662 
663   // Add linear parameters
664   if (done[linearClause] && clauseSegments[pos[linearClause]]) {
665     auto linearStepType = parser.getBuilder().getI32Type();
666     SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
667     if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location,
668                                       result.operands)) ||
669         failed(parser.resolveOperands(linearSteps, linearStepTypes,
670                                       linearSteps[0].location,
671                                       result.operands)))
672       return failure();
673   }
674 
675   // Add schedule parameters
676   if (done[scheduleClause] && !schedule.empty()) {
677     schedule[0] = llvm::toUpper(schedule[0]);
678     if (Optional<ClauseScheduleKind> sched =
679             symbolizeClauseScheduleKind(schedule)) {
680       auto attr = ClauseScheduleKindAttr::get(parser.getContext(), *sched);
681       result.addAttribute("schedule_val", attr);
682     } else {
683       return parser.emitError(parser.getCurrentLocation(),
684                               "invalid schedule kind");
685     }
686     if (!modifiers.empty()) {
687       SMLoc loc = parser.getCurrentLocation();
688       if (Optional<ScheduleModifier> mod =
689               symbolizeScheduleModifier(modifiers[0])) {
690         result.addAttribute(
691             "schedule_modifier",
692             ScheduleModifierAttr::get(parser.getContext(), *mod));
693       } else {
694         return parser.emitError(loc, "invalid schedule modifier");
695       }
696       // Only SIMD attribute is allowed here!
697       if (modifiers.size() > 1) {
698         assert(symbolizeScheduleModifier(modifiers[1]) ==
699                ScheduleModifier::simd);
700         auto attr = UnitAttr::get(parser.getBuilder().getContext());
701         result.addAttribute("simd_modifier", attr);
702       }
703     }
704     if (scheduleChunkSize)
705       parser.resolveOperand(*scheduleChunkSize, scheduleChunkType,
706                             result.operands);
707   }
708 
709   segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end());
710 
711   return success();
712 }
713 
714 //===----------------------------------------------------------------------===//
715 // Verifier for SectionsOp
716 //===----------------------------------------------------------------------===//
717 
718 LogicalResult SectionsOp::verify() {
719   if (allocate_vars().size() != allocators_vars().size())
720     return emitError(
721         "expected equal sizes for allocate and allocator variables");
722 
723   for (auto &inst : *region().begin()) {
724     if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
725       return emitOpError()
726              << "expected omp.section op or terminator op inside region";
727     }
728   }
729 
730   return verifyReductionVarList(*this, reductions(), reduction_vars());
731 }
732 
733 /// Parses an OpenMP Workshare Loop operation
734 ///
735 /// wsloop ::= `omp.wsloop` loop-control clause-list
736 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
737 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
738 /// steps := `step` `(`ssa-id-list`)`
739 /// clause-list ::= clause clause-list | empty
740 /// clause ::= linear | schedule | collapse | nowait | ordered | order
741 ///          | reduction
742 ParseResult WsLoopOp::parse(OpAsmParser &parser, OperationState &result) {
743   // Parse an opening `(` followed by induction variables followed by `)`
744   SmallVector<OpAsmParser::OperandType> ivs;
745   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
746                                      OpAsmParser::Delimiter::Paren))
747     return failure();
748 
749   int numIVs = static_cast<int>(ivs.size());
750   Type loopVarType;
751   if (parser.parseColonType(loopVarType))
752     return failure();
753 
754   // Parse loop bounds.
755   SmallVector<OpAsmParser::OperandType> lower;
756   if (parser.parseEqual() ||
757       parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
758       parser.resolveOperands(lower, loopVarType, result.operands))
759     return failure();
760 
761   SmallVector<OpAsmParser::OperandType> upper;
762   if (parser.parseKeyword("to") ||
763       parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
764       parser.resolveOperands(upper, loopVarType, result.operands))
765     return failure();
766 
767   if (succeeded(parser.parseOptionalKeyword("inclusive"))) {
768     auto attr = UnitAttr::get(parser.getBuilder().getContext());
769     result.addAttribute("inclusive", attr);
770   }
771 
772   // Parse step values.
773   SmallVector<OpAsmParser::OperandType> steps;
774   if (parser.parseKeyword("step") ||
775       parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
776       parser.resolveOperands(steps, loopVarType, result.operands))
777     return failure();
778 
779   SmallVector<ClauseType> clauses = {
780       linearClause,  reductionClause, collapseClause, orderClause,
781       orderedClause, nowaitClause,    scheduleClause};
782   SmallVector<int> segments{numIVs, numIVs, numIVs};
783   if (failed(parseClauses(parser, result, clauses, segments)))
784     return failure();
785 
786   result.addAttribute("operand_segment_sizes",
787                       parser.getBuilder().getI32VectorAttr(segments));
788 
789   // Now parse the body.
790   Region *body = result.addRegion();
791   SmallVector<Type> ivTypes(numIVs, loopVarType);
792   SmallVector<OpAsmParser::OperandType> blockArgs(ivs);
793   if (parser.parseRegion(*body, blockArgs, ivTypes))
794     return failure();
795   return success();
796 }
797 
798 void WsLoopOp::print(OpAsmPrinter &p) {
799   auto args = getRegion().front().getArguments();
800   p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound()
801     << ") to (" << upperBound() << ") ";
802   if (inclusive()) {
803     p << "inclusive ";
804   }
805   p << "step (" << step() << ") ";
806 
807   if (!linear_vars().empty())
808     printLinearClause(p, linear_vars(), linear_step_vars());
809 
810   if (auto sched = schedule_val())
811     printScheduleClause(p, sched.getValue(), schedule_modifier(),
812                         simd_modifier(), schedule_chunk_var());
813 
814   if (auto collapse = collapse_val())
815     p << "collapse(" << collapse << ") ";
816 
817   if (nowait())
818     p << "nowait ";
819 
820   if (auto ordered = ordered_val())
821     p << "ordered(" << ordered << ") ";
822 
823   if (auto order = order_val())
824     p << "order(" << stringifyClauseOrderKind(*order) << ") ";
825 
826   if (!reduction_vars().empty()) {
827     printReductionVarList(p << "reduction(", *this, reduction_vars(),
828                           reduction_vars().getTypes(), reductions());
829     p << ")";
830   }
831 
832   p << ' ';
833   p.printRegion(region(), /*printEntryBlockArgs=*/false);
834 }
835 
836 //===----------------------------------------------------------------------===//
837 // ReductionOp
838 //===----------------------------------------------------------------------===//
839 
840 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
841                                               Region &region) {
842   if (parser.parseOptionalKeyword("atomic"))
843     return success();
844   return parser.parseRegion(region);
845 }
846 
847 static void printAtomicReductionRegion(OpAsmPrinter &printer,
848                                        ReductionDeclareOp op, Region &region) {
849   if (region.empty())
850     return;
851   printer << "atomic ";
852   printer.printRegion(region);
853 }
854 
855 LogicalResult ReductionDeclareOp::verify() {
856   if (initializerRegion().empty())
857     return emitOpError() << "expects non-empty initializer region";
858   Block &initializerEntryBlock = initializerRegion().front();
859   if (initializerEntryBlock.getNumArguments() != 1 ||
860       initializerEntryBlock.getArgument(0).getType() != type()) {
861     return emitOpError() << "expects initializer region with one argument "
862                             "of the reduction type";
863   }
864 
865   for (YieldOp yieldOp : initializerRegion().getOps<YieldOp>()) {
866     if (yieldOp.results().size() != 1 ||
867         yieldOp.results().getTypes()[0] != type())
868       return emitOpError() << "expects initializer region to yield a value "
869                               "of the reduction type";
870   }
871 
872   if (reductionRegion().empty())
873     return emitOpError() << "expects non-empty reduction region";
874   Block &reductionEntryBlock = reductionRegion().front();
875   if (reductionEntryBlock.getNumArguments() != 2 ||
876       reductionEntryBlock.getArgumentTypes()[0] !=
877           reductionEntryBlock.getArgumentTypes()[1] ||
878       reductionEntryBlock.getArgumentTypes()[0] != type())
879     return emitOpError() << "expects reduction region with two arguments of "
880                             "the reduction type";
881   for (YieldOp yieldOp : reductionRegion().getOps<YieldOp>()) {
882     if (yieldOp.results().size() != 1 ||
883         yieldOp.results().getTypes()[0] != type())
884       return emitOpError() << "expects reduction region to yield a value "
885                               "of the reduction type";
886   }
887 
888   if (atomicReductionRegion().empty())
889     return success();
890 
891   Block &atomicReductionEntryBlock = atomicReductionRegion().front();
892   if (atomicReductionEntryBlock.getNumArguments() != 2 ||
893       atomicReductionEntryBlock.getArgumentTypes()[0] !=
894           atomicReductionEntryBlock.getArgumentTypes()[1])
895     return emitOpError() << "expects atomic reduction region with two "
896                             "arguments of the same type";
897   auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
898                      .dyn_cast<PointerLikeType>();
899   if (!ptrType || ptrType.getElementType() != type())
900     return emitOpError() << "expects atomic reduction region arguments to "
901                             "be accumulators containing the reduction type";
902   return success();
903 }
904 
905 LogicalResult ReductionOp::verify() {
906   // TODO: generalize this to an op interface when there is more than one op
907   // that supports reductions.
908   auto container = (*this)->getParentOfType<WsLoopOp>();
909   for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
910     if (container.reduction_vars()[i] == accumulator())
911       return success();
912 
913   return emitOpError() << "the accumulator is not used by the parent";
914 }
915 
916 //===----------------------------------------------------------------------===//
917 // WsLoopOp
918 //===----------------------------------------------------------------------===//
919 
/// Convenience builder that only takes the loop bounds and steps.
///
/// Forwards to the full builder with empty linear and reduction operand
/// lists, null values for every optional clause attribute/operand (and `false`
/// for the boolean flags), then appends `attributes` to the operation state.
void WsLoopOp::build(OpBuilder &builder, OperationState &state,
                     ValueRange lowerBound, ValueRange upperBound,
                     ValueRange step, ArrayRef<NamedAttribute> attributes) {
  build(builder, state, lowerBound, upperBound, step,
        /*linear_vars=*/ValueRange(),
        /*linear_step_vars=*/ValueRange(), /*reduction_vars=*/ValueRange(),
        /*reductions=*/nullptr, /*schedule_val=*/nullptr,
        /*schedule_chunk_var=*/nullptr, /*schedule_modifier=*/nullptr,
        /*simd_modifier=*/false, /*collapse_val=*/nullptr, /*nowait=*/false,
        /*ordered_val=*/nullptr, /*order_val=*/nullptr, /*inclusive=*/false);
  // Caller-supplied attributes are added after the full builder has populated
  // the state.
  state.addAttributes(attributes);
}
932 
933 LogicalResult WsLoopOp::verify() {
934   return verifyReductionVarList(*this, reductions(), reduction_vars());
935 }
936 
937 //===----------------------------------------------------------------------===//
938 // Verifier for critical construct (2.17.1)
939 //===----------------------------------------------------------------------===//
940 
941 LogicalResult CriticalDeclareOp::verify() {
942   return verifySynchronizationHint(*this, hint_val());
943 }
944 
945 LogicalResult CriticalOp::verify() {
946   if (nameAttr()) {
947     SymbolRefAttr symbolRef = nameAttr();
948     auto decl = SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(
949         *this, symbolRef);
950     if (!decl) {
951       return emitOpError() << "expected symbol reference " << symbolRef
952                            << " to point to a critical declaration";
953     }
954   }
955 
956   return success();
957 }
958 
959 //===----------------------------------------------------------------------===//
960 // Verifier for ordered construct
961 //===----------------------------------------------------------------------===//
962 
963 LogicalResult OrderedOp::verify() {
964   auto container = (*this)->getParentOfType<WsLoopOp>();
965   if (!container || !container.ordered_valAttr() ||
966       container.ordered_valAttr().getInt() == 0)
967     return emitOpError() << "ordered depend directive must be closely "
968                          << "nested inside a worksharing-loop with ordered "
969                          << "clause with parameter present";
970 
971   if (container.ordered_valAttr().getInt() !=
972       (int64_t)num_loops_val().getValue())
973     return emitOpError() << "number of variables in depend clause does not "
974                          << "match number of iteration variables in the "
975                          << "doacross loop";
976 
977   return success();
978 }
979 
980 LogicalResult OrderedRegionOp::verify() {
981   // TODO: The code generation for ordered simd directive is not supported yet.
982   if (simd())
983     return failure();
984 
985   if (auto container = (*this)->getParentOfType<WsLoopOp>()) {
986     if (!container.ordered_valAttr() ||
987         container.ordered_valAttr().getInt() != 0)
988       return emitOpError() << "ordered region must be closely nested inside "
989                            << "a worksharing-loop region with an ordered "
990                            << "clause without parameter present";
991   }
992 
993   return success();
994 }
995 
996 //===----------------------------------------------------------------------===//
997 // Verifier for AtomicReadOp
998 //===----------------------------------------------------------------------===//
999 
1000 LogicalResult AtomicReadOp::verify() {
1001   if (auto mo = memory_order_val()) {
1002     if (*mo == ClauseMemoryOrderKind::acq_rel ||
1003         *mo == ClauseMemoryOrderKind::release) {
1004       return emitError(
1005           "memory-order must not be acq_rel or release for atomic reads");
1006     }
1007   }
1008   if (x() == v())
1009     return emitError(
1010         "read and write must not be to the same location for atomic reads");
1011   return verifySynchronizationHint(*this, hint_val());
1012 }
1013 
1014 //===----------------------------------------------------------------------===//
1015 // Verifier for AtomicWriteOp
1016 //===----------------------------------------------------------------------===//
1017 
1018 LogicalResult AtomicWriteOp::verify() {
1019   if (auto mo = memory_order_val()) {
1020     if (*mo == ClauseMemoryOrderKind::acq_rel ||
1021         *mo == ClauseMemoryOrderKind::acquire) {
1022       return emitError(
1023           "memory-order must not be acq_rel or acquire for atomic writes");
1024     }
1025   }
1026   return verifySynchronizationHint(*this, hint_val());
1027 }
1028 
1029 //===----------------------------------------------------------------------===//
1030 // Verifier for AtomicUpdateOp
1031 //===----------------------------------------------------------------------===//
1032 
1033 LogicalResult AtomicUpdateOp::verify() {
1034   if (auto mo = memory_order_val()) {
1035     if (*mo == ClauseMemoryOrderKind::acq_rel ||
1036         *mo == ClauseMemoryOrderKind::acquire) {
1037       return emitError(
1038           "memory-order must not be acq_rel or acquire for atomic updates");
1039     }
1040   }
1041 
1042   if (region().getNumArguments() != 1)
1043     return emitError("the region must accept exactly one argument");
1044 
1045   if (x().getType().cast<PointerLikeType>().getElementType() !=
1046       region().getArgument(0).getType()) {
1047     return emitError("the type of the operand must be a pointer type whose "
1048                      "element type is the same as that of the region argument");
1049   }
1050 
1051   YieldOp yieldOp = *region().getOps<YieldOp>().begin();
1052 
1053   if (yieldOp.results().size() != 1)
1054     return emitError("only updated value must be returned");
1055   if (yieldOp.results().front().getType() != region().getArgument(0).getType())
1056     return emitError("input and yielded value must have the same type");
1057   return success();
1058 }
1059 
1060 //===----------------------------------------------------------------------===//
1061 // Verifier for AtomicCaptureOp
1062 //===----------------------------------------------------------------------===//
1063 
1064 LogicalResult AtomicCaptureOp::verify() {
1065   Block::OpListType &ops = region().front().getOperations();
1066   if (ops.size() != 3)
1067     return emitError()
1068            << "expected three operations in omp.atomic.capture region (one "
1069               "terminator, and two atomic ops)";
1070   auto &firstOp = ops.front();
1071   auto &secondOp = *ops.getNextNode(firstOp);
1072   auto firstReadStmt = dyn_cast<AtomicReadOp>(firstOp);
1073   auto firstUpdateStmt = dyn_cast<AtomicUpdateOp>(firstOp);
1074   auto secondReadStmt = dyn_cast<AtomicReadOp>(secondOp);
1075   auto secondUpdateStmt = dyn_cast<AtomicUpdateOp>(secondOp);
1076   auto secondWriteStmt = dyn_cast<AtomicWriteOp>(secondOp);
1077 
1078   if (!((firstUpdateStmt && secondReadStmt) ||
1079         (firstReadStmt && secondUpdateStmt) ||
1080         (firstReadStmt && secondWriteStmt)))
1081     return ops.front().emitError()
1082            << "invalid sequence of operations in the capture region";
1083   if (firstUpdateStmt && secondReadStmt &&
1084       firstUpdateStmt.x() != secondReadStmt.x())
1085     return firstUpdateStmt.emitError()
1086            << "updated variable in omp.atomic.update must be captured in "
1087               "second operation";
1088   if (firstReadStmt && secondUpdateStmt &&
1089       firstReadStmt.x() != secondUpdateStmt.x())
1090     return firstReadStmt.emitError()
1091            << "captured variable in omp.atomic.read must be updated in second "
1092               "operation";
1093   if (firstReadStmt && secondWriteStmt &&
1094       firstReadStmt.x() != secondWriteStmt.address())
1095     return firstReadStmt.emitError()
1096            << "captured variable in omp.atomic.read must be updated in "
1097               "second operation";
1098   return success();
1099 }
1100 
1101 #define GET_ATTRDEF_CLASSES
1102 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
1103 
1104 #define GET_OP_CLASSES
1105 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
1106