1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/OperationSupport.h"
19 
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/SmallString.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/ADT/StringSwitch.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include <cstddef>
27 
28 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
29 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
30 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
31 
32 using namespace mlir;
33 using namespace mlir::omp;
34 
35 namespace {
36 /// Model for pointer-like types that already provide a `getElementType` method.
37 template <typename T>
38 struct PointerLikeModel
39     : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
40   Type getElementType(Type pointer) const {
41     return pointer.cast<T>().getElementType();
42   }
43 };
44 } // namespace
45 
/// Registers the dialect's operations and attributes (from the TableGen
/// generated lists included below) and attaches the `PointerLikeType`
/// interface to the LLVM dialect pointer type and the builtin memref type.
void OpenMPDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
      >();

  // Both types provide their own `getElementType`, so the generic forwarding
  // PointerLikeModel defined above is sufficient for them.
  LLVM::LLVMPointerType::attachInterface<
      PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
  MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
}
60 
61 //===----------------------------------------------------------------------===//
62 // ParallelOp
63 //===----------------------------------------------------------------------===//
64 
65 void ParallelOp::build(OpBuilder &builder, OperationState &state,
66                        ArrayRef<NamedAttribute> attributes) {
67   ParallelOp::build(
68       builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
69       /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
70       /*proc_bind_val=*/nullptr);
71   state.addAttributes(attributes);
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // Parser and printer for Allocate Clause
76 //===----------------------------------------------------------------------===//
77 
78 /// Parse an allocate clause with allocators and a list of operands with types.
79 ///
80 /// allocate-operand-list :: = allocate-operand |
81 ///                            allocator-operand `,` allocate-operand-list
82 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
83 /// ssa-id-and-type ::= ssa-id `:` type
84 static ParseResult parseAllocateAndAllocator(
85     OpAsmParser &parser,
86     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operandsAllocate,
87     SmallVectorImpl<Type> &typesAllocate,
88     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operandsAllocator,
89     SmallVectorImpl<Type> &typesAllocator) {
90 
91   return parser.parseCommaSeparatedList([&]() -> ParseResult {
92     OpAsmParser::UnresolvedOperand operand;
93     Type type;
94     if (parser.parseOperand(operand) || parser.parseColonType(type))
95       return failure();
96     operandsAllocator.push_back(operand);
97     typesAllocator.push_back(type);
98     if (parser.parseArrow())
99       return failure();
100     if (parser.parseOperand(operand) || parser.parseColonType(type))
101       return failure();
102 
103     operandsAllocate.push_back(operand);
104     typesAllocate.push_back(type);
105     return success();
106   });
107 }
108 
109 /// Print allocate clause
110 static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op,
111                                       OperandRange varsAllocate,
112                                       TypeRange typesAllocate,
113                                       OperandRange varsAllocator,
114                                       TypeRange typesAllocator) {
115   for (unsigned i = 0; i < varsAllocate.size(); ++i) {
116     std::string separator = i == varsAllocate.size() - 1 ? "" : ", ";
117     p << varsAllocator[i] << " : " << typesAllocator[i] << " -> ";
118     p << varsAllocate[i] << " : " << typesAllocate[i] << separator;
119   }
120 }
121 
122 //===----------------------------------------------------------------------===//
123 // Parser and printer for a clause attribute (StringEnumAttr)
124 //===----------------------------------------------------------------------===//
125 
126 template <typename ClauseAttr>
127 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
128   using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
129   StringRef enumStr;
130   SMLoc loc = parser.getCurrentLocation();
131   if (parser.parseKeyword(&enumStr))
132     return failure();
133   if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
134     attr = ClauseAttr::get(parser.getContext(), *enumValue);
135     return success();
136   }
137   return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
138 }
139 
140 template <typename ClauseAttr>
141 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
142   p << stringifyEnum(attr.getValue());
143 }
144 
145 LogicalResult ParallelOp::verify() {
146   if (allocate_vars().size() != allocators_vars().size())
147     return emitError(
148         "expected equal sizes for allocate and allocator variables");
149   return success();
150 }
151 
152 //===----------------------------------------------------------------------===//
153 // Parser and printer for Linear Clause
154 //===----------------------------------------------------------------------===//
155 
156 /// linear ::= `linear` `(` linear-list `)`
157 /// linear-list := linear-val | linear-val linear-list
158 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
159 static ParseResult
160 parseLinearClause(OpAsmParser &parser,
161                   SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
162                   SmallVectorImpl<Type> &types,
163                   SmallVectorImpl<OpAsmParser::UnresolvedOperand> &stepVars) {
164   do {
165     OpAsmParser::UnresolvedOperand var;
166     Type type;
167     OpAsmParser::UnresolvedOperand stepVar;
168     if (parser.parseOperand(var) || parser.parseEqual() ||
169         parser.parseOperand(stepVar) || parser.parseColonType(type))
170       return failure();
171 
172     vars.push_back(var);
173     types.push_back(type);
174     stepVars.push_back(stepVar);
175   } while (succeeded(parser.parseOptionalComma()));
176   return success();
177 }
178 
179 /// Print Linear Clause
180 static void printLinearClause(OpAsmPrinter &p, Operation *op,
181                               ValueRange linearVars, TypeRange linearVarTypes,
182                               ValueRange linearStepVars) {
183   size_t linearVarsSize = linearVars.size();
184   for (unsigned i = 0; i < linearVarsSize; ++i) {
185     std::string separator = i == linearVarsSize - 1 ? "" : ", ";
186     p << linearVars[i];
187     if (linearStepVars.size() > i)
188       p << " = " << linearStepVars[i];
189     p << " : " << linearVars[i].getType() << separator;
190   }
191 }
192 
193 //===----------------------------------------------------------------------===//
194 // Parser, printer and verifier for Schedule Clause
195 //===----------------------------------------------------------------------===//
196 
197 static ParseResult
198 verifyScheduleModifiers(OpAsmParser &parser,
199                         SmallVectorImpl<SmallString<12>> &modifiers) {
200   if (modifiers.size() > 2)
201     return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
202   for (const auto &mod : modifiers) {
203     // Translate the string. If it has no value, then it was not a valid
204     // modifier!
205     auto symbol = symbolizeScheduleModifier(mod);
206     if (!symbol.hasValue())
207       return parser.emitError(parser.getNameLoc())
208              << " unknown modifier type: " << mod;
209   }
210 
211   // If we have one modifier that is "simd", then stick a "none" modiifer in
212   // index 0.
213   if (modifiers.size() == 1) {
214     if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
215       modifiers.push_back(modifiers[0]);
216       modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
217     }
218   } else if (modifiers.size() == 2) {
219     // If there are two modifier:
220     // First modifier should not be simd, second one should be simd
221     if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
222         symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
223       return parser.emitError(parser.getNameLoc())
224              << " incorrect modifier order";
225   }
226   return success();
227 }
228 
229 /// schedule ::= `schedule` `(` sched-list `)`
230 /// sched-list ::= sched-val | sched-val sched-list |
231 ///                sched-val `,` sched-modifier
232 /// sched-val ::= sched-with-chunk | sched-wo-chunk
233 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
234 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
235 /// sched-wo-chunk ::=  `auto` | `runtime`
236 /// sched-modifier ::=  sched-mod-val | sched-mod-val `,` sched-mod-val
237 /// sched-mod-val ::=  `monotonic` | `nonmonotonic` | `simd` | `none`
238 static ParseResult parseScheduleClause(
239     OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
240     ScheduleModifierAttr &schedule_modifier, UnitAttr &simdModifier,
241     Optional<OpAsmParser::UnresolvedOperand> &chunkSize, Type &chunkType) {
242   StringRef keyword;
243   if (parser.parseKeyword(&keyword))
244     return failure();
245   llvm::Optional<mlir::omp::ClauseScheduleKind> schedule =
246       symbolizeClauseScheduleKind(keyword);
247   if (!schedule)
248     return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
249 
250   scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
251   switch (*schedule) {
252   case ClauseScheduleKind::Static:
253   case ClauseScheduleKind::Dynamic:
254   case ClauseScheduleKind::Guided:
255     if (succeeded(parser.parseOptionalEqual())) {
256       chunkSize = OpAsmParser::UnresolvedOperand{};
257       if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
258         return failure();
259     } else {
260       chunkSize = llvm::NoneType::None;
261     }
262     break;
263   case ClauseScheduleKind::Auto:
264   case ClauseScheduleKind::Runtime:
265     chunkSize = llvm::NoneType::None;
266   }
267 
268   // If there is a comma, we have one or more modifiers..
269   SmallVector<SmallString<12>> modifiers;
270   while (succeeded(parser.parseOptionalComma())) {
271     StringRef mod;
272     if (parser.parseKeyword(&mod))
273       return failure();
274     modifiers.push_back(mod);
275   }
276 
277   if (verifyScheduleModifiers(parser, modifiers))
278     return failure();
279 
280   if (!modifiers.empty()) {
281     SMLoc loc = parser.getCurrentLocation();
282     if (Optional<ScheduleModifier> mod =
283             symbolizeScheduleModifier(modifiers[0])) {
284       schedule_modifier = ScheduleModifierAttr::get(parser.getContext(), *mod);
285     } else {
286       return parser.emitError(loc, "invalid schedule modifier");
287     }
288     // Only SIMD attribute is allowed here!
289     if (modifiers.size() > 1) {
290       assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
291       simdModifier = UnitAttr::get(parser.getBuilder().getContext());
292     }
293   }
294 
295   return success();
296 }
297 
298 /// Print schedule clause
299 static void printScheduleClause(OpAsmPrinter &p, Operation *op,
300                                 ClauseScheduleKindAttr schedAttr,
301                                 ScheduleModifierAttr modifier, UnitAttr simd,
302                                 Value scheduleChunkVar,
303                                 Type scheduleChunkType) {
304   p << stringifyClauseScheduleKind(schedAttr.getValue());
305   if (scheduleChunkVar)
306     p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType();
307   if (modifier)
308     p << ", " << stringifyScheduleModifier(modifier.getValue());
309   if (simd)
310     p << ", simd";
311 }
312 
313 //===----------------------------------------------------------------------===//
314 // Parser, printer and verifier for ReductionVarList
315 //===----------------------------------------------------------------------===//
316 
317 /// reduction-entry-list ::= reduction-entry
318 ///                        | reduction-entry-list `,` reduction-entry
319 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
320 static ParseResult
321 parseReductionVarList(OpAsmParser &parser,
322                       SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
323                       SmallVectorImpl<Type> &types,
324                       ArrayAttr &redcuctionSymbols) {
325   SmallVector<SymbolRefAttr> reductionVec;
326   do {
327     if (parser.parseAttribute(reductionVec.emplace_back()) ||
328         parser.parseArrow() || parser.parseOperand(operands.emplace_back()) ||
329         parser.parseColonType(types.emplace_back()))
330       return failure();
331   } while (succeeded(parser.parseOptionalComma()));
332   SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
333   redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions);
334   return success();
335 }
336 
337 /// Print Reduction clause
338 static void printReductionVarList(OpAsmPrinter &p, Operation *op,
339                                   OperandRange reductionVars,
340                                   TypeRange reductionTypes,
341                                   Optional<ArrayAttr> reductions) {
342   for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
343     if (i != 0)
344       p << ", ";
345     p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
346       << reductionVars[i].getType();
347   }
348 }
349 
350 /// Verifies Reduction Clause
351 static LogicalResult verifyReductionVarList(Operation *op,
352                                             Optional<ArrayAttr> reductions,
353                                             OperandRange reductionVars) {
354   if (!reductionVars.empty()) {
355     if (!reductions || reductions->size() != reductionVars.size())
356       return op->emitOpError()
357              << "expected as many reduction symbol references "
358                 "as reduction variables";
359   } else {
360     if (reductions)
361       return op->emitOpError() << "unexpected reduction symbol references";
362     return success();
363   }
364 
365   // TODO: The followings should be done in
366   // SymbolUserOpInterface::verifySymbolUses.
367   DenseSet<Value> accumulators;
368   for (auto args : llvm::zip(reductionVars, *reductions)) {
369     Value accum = std::get<0>(args);
370 
371     if (!accumulators.insert(accum).second)
372       return op->emitOpError() << "accumulator variable used more than once";
373 
374     Type varType = accum.getType().cast<PointerLikeType>();
375     auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
376     auto decl =
377         SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
378     if (!decl)
379       return op->emitOpError() << "expected symbol reference " << symbolRef
380                                << " to point to a reduction declaration";
381 
382     if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
383       return op->emitOpError()
384              << "expected accumulator (" << varType
385              << ") to be the same type as reduction declaration ("
386              << decl.getAccumulatorType() << ")";
387   }
388 
389   return success();
390 }
391 
392 //===----------------------------------------------------------------------===//
393 // Parser, printer and verifier for Synchronization Hint (2.17.12)
394 //===----------------------------------------------------------------------===//
395 
396 /// Parses a Synchronization Hint clause. The value of hint is an integer
397 /// which is a combination of different hints from `omp_sync_hint_t`.
398 ///
399 /// hint-clause = `hint` `(` hint-value `)`
400 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
401                                             IntegerAttr &hintAttr) {
402   StringRef hintKeyword;
403   int64_t hint = 0;
404   do {
405     if (failed(parser.parseKeyword(&hintKeyword)))
406       return failure();
407     if (hintKeyword == "uncontended")
408       hint |= 1;
409     else if (hintKeyword == "contended")
410       hint |= 2;
411     else if (hintKeyword == "nonspeculative")
412       hint |= 4;
413     else if (hintKeyword == "speculative")
414       hint |= 8;
415     else
416       return parser.emitError(parser.getCurrentLocation())
417              << hintKeyword << " is not a valid hint";
418   } while (succeeded(parser.parseOptionalComma()));
419   hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
420   return success();
421 }
422 
423 /// Prints a Synchronization Hint clause
424 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
425                                      IntegerAttr hintAttr) {
426   int64_t hint = hintAttr.getInt();
427 
428   if (hint == 0)
429     return;
430 
431   // Helper function to get n-th bit from the right end of `value`
432   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
433 
434   bool uncontended = bitn(hint, 0);
435   bool contended = bitn(hint, 1);
436   bool nonspeculative = bitn(hint, 2);
437   bool speculative = bitn(hint, 3);
438 
439   SmallVector<StringRef> hints;
440   if (uncontended)
441     hints.push_back("uncontended");
442   if (contended)
443     hints.push_back("contended");
444   if (nonspeculative)
445     hints.push_back("nonspeculative");
446   if (speculative)
447     hints.push_back("speculative");
448 
449   llvm::interleaveComma(hints, p);
450 }
451 
452 /// Verifies a synchronization hint clause
453 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
454 
455   // Helper function to get n-th bit from the right end of `value`
456   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
457 
458   bool uncontended = bitn(hint, 0);
459   bool contended = bitn(hint, 1);
460   bool nonspeculative = bitn(hint, 2);
461   bool speculative = bitn(hint, 3);
462 
463   if (uncontended && contended)
464     return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
465                                 "omp_sync_hint_contended cannot be combined";
466   if (nonspeculative && speculative)
467     return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
468                                 "omp_sync_hint_speculative cannot be combined.";
469   return success();
470 }
471 
472 //===----------------------------------------------------------------------===//
473 // Verifier for SectionsOp
474 //===----------------------------------------------------------------------===//
475 
476 LogicalResult SectionsOp::verify() {
477   if (allocate_vars().size() != allocators_vars().size())
478     return emitError(
479         "expected equal sizes for allocate and allocator variables");
480 
481   return verifyReductionVarList(*this, reductions(), reduction_vars());
482 }
483 
484 LogicalResult SectionsOp::verifyRegions() {
485   for (auto &inst : *region().begin()) {
486     if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
487       return emitOpError()
488              << "expected omp.section op or terminator op inside region";
489     }
490   }
491 
492   return success();
493 }
494 
495 LogicalResult SingleOp::verify() {
496   // Check for allocate clause restrictions
497   if (allocate_vars().size() != allocators_vars().size())
498     return emitError(
499         "expected equal sizes for allocate and allocator variables");
500 
501   return success();
502 }
503 
504 //===----------------------------------------------------------------------===//
505 // WsLoopOp
506 //===----------------------------------------------------------------------===//
507 
508 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
509 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
510 /// steps := `step` `(`ssa-id-list`)`
511 ParseResult
512 parseWsLoopControl(OpAsmParser &parser, Region &region,
513                    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &lowerBound,
514                    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &upperBound,
515                    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &steps,
516                    SmallVectorImpl<Type> &loopVarTypes, UnitAttr &inclusive) {
517   // Parse an opening `(` followed by induction variables followed by `)`
518   SmallVector<OpAsmParser::UnresolvedOperand> ivs;
519   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
520                                      OpAsmParser::Delimiter::Paren))
521     return failure();
522 
523   size_t numIVs = ivs.size();
524   Type loopVarType;
525   if (parser.parseColonType(loopVarType))
526     return failure();
527 
528   // Parse loop bounds.
529   if (parser.parseEqual() ||
530       parser.parseOperandList(lowerBound, numIVs,
531                               OpAsmParser::Delimiter::Paren))
532     return failure();
533   if (parser.parseKeyword("to") ||
534       parser.parseOperandList(upperBound, numIVs,
535                               OpAsmParser::Delimiter::Paren))
536     return failure();
537 
538   if (succeeded(parser.parseOptionalKeyword("inclusive"))) {
539     inclusive = UnitAttr::get(parser.getBuilder().getContext());
540   }
541 
542   // Parse step values.
543   if (parser.parseKeyword("step") ||
544       parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren))
545     return failure();
546 
547   // Now parse the body.
548   loopVarTypes = SmallVector<Type>(numIVs, loopVarType);
549   SmallVector<OpAsmParser::UnresolvedOperand> blockArgs(ivs);
550   if (parser.parseRegion(region, blockArgs, loopVarTypes))
551     return failure();
552   return success();
553 }
554 
555 void printWsLoopControl(OpAsmPrinter &p, Operation *op, Region &region,
556                         ValueRange lowerBound, ValueRange upperBound,
557                         ValueRange steps, TypeRange loopVarTypes,
558                         UnitAttr inclusive) {
559   auto args = region.front().getArguments();
560   p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound
561     << ") to (" << upperBound << ") ";
562   if (inclusive)
563     p << "inclusive ";
564   p << "step (" << steps << ") ";
565   p.printRegion(region, /*printEntryBlockArgs=*/false);
566 }
567 
568 //===----------------------------------------------------------------------===//
569 // SimdLoopOp
570 //===----------------------------------------------------------------------===//
571 /// Parses an OpenMP Simd construct [2.9.3.1]
572 ///
573 /// simdloop ::= `omp.simdloop` loop-control clause-list
574 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
575 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps
576 /// steps := `step` `(`ssa-id-list`)`
577 /// clause-list ::= clause clause-list | empty
578 /// clause ::= TODO
579 ParseResult SimdLoopOp::parse(OpAsmParser &parser, OperationState &result) {
580   // Parse an opening `(` followed by induction variables followed by `)`
581   SmallVector<OpAsmParser::UnresolvedOperand> ivs;
582   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
583                                      OpAsmParser::Delimiter::Paren))
584     return failure();
585   int numIVs = static_cast<int>(ivs.size());
586   Type loopVarType;
587   if (parser.parseColonType(loopVarType))
588     return failure();
589   // Parse loop bounds.
590   SmallVector<OpAsmParser::UnresolvedOperand> lower;
591   if (parser.parseEqual() ||
592       parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
593       parser.resolveOperands(lower, loopVarType, result.operands))
594     return failure();
595   SmallVector<OpAsmParser::UnresolvedOperand> upper;
596   if (parser.parseKeyword("to") ||
597       parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
598       parser.resolveOperands(upper, loopVarType, result.operands))
599     return failure();
600 
601   // Parse step values.
602   SmallVector<OpAsmParser::UnresolvedOperand> steps;
603   if (parser.parseKeyword("step") ||
604       parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
605       parser.resolveOperands(steps, loopVarType, result.operands))
606     return failure();
607 
608   SmallVector<int> segments{numIVs, numIVs, numIVs};
609   // TODO: Add parseClauses() when we support clauses
610   result.addAttribute("operand_segment_sizes",
611                       parser.getBuilder().getI32VectorAttr(segments));
612 
613   // Now parse the body.
614   Region *body = result.addRegion();
615   SmallVector<Type> ivTypes(numIVs, loopVarType);
616   SmallVector<OpAsmParser::UnresolvedOperand> blockArgs(ivs);
617   if (parser.parseRegion(*body, blockArgs, ivTypes))
618     return failure();
619   return success();
620 }
621 
622 void SimdLoopOp::print(OpAsmPrinter &p) {
623   auto args = getRegion().front().getArguments();
624   p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound()
625     << ") to (" << upperBound() << ") ";
626   p << "step (" << step() << ") ";
627 
628   p.printRegion(region(), /*printEntryBlockArgs=*/false);
629 }
630 
631 //===----------------------------------------------------------------------===//
632 // Verifier for Simd construct [2.9.3.1]
633 //===----------------------------------------------------------------------===//
634 
635 LogicalResult SimdLoopOp::verify() {
636   if (this->lowerBound().empty()) {
637     return emitOpError() << "empty lowerbound for simd loop operation";
638   }
639   return success();
640 }
641 
642 //===----------------------------------------------------------------------===//
643 // ReductionOp
644 //===----------------------------------------------------------------------===//
645 
646 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
647                                               Region &region) {
648   if (parser.parseOptionalKeyword("atomic"))
649     return success();
650   return parser.parseRegion(region);
651 }
652 
653 static void printAtomicReductionRegion(OpAsmPrinter &printer,
654                                        ReductionDeclareOp op, Region &region) {
655   if (region.empty())
656     return;
657   printer << "atomic ";
658   printer.printRegion(region);
659 }
660 
661 LogicalResult ReductionDeclareOp::verifyRegions() {
662   if (initializerRegion().empty())
663     return emitOpError() << "expects non-empty initializer region";
664   Block &initializerEntryBlock = initializerRegion().front();
665   if (initializerEntryBlock.getNumArguments() != 1 ||
666       initializerEntryBlock.getArgument(0).getType() != type()) {
667     return emitOpError() << "expects initializer region with one argument "
668                             "of the reduction type";
669   }
670 
671   for (YieldOp yieldOp : initializerRegion().getOps<YieldOp>()) {
672     if (yieldOp.results().size() != 1 ||
673         yieldOp.results().getTypes()[0] != type())
674       return emitOpError() << "expects initializer region to yield a value "
675                               "of the reduction type";
676   }
677 
678   if (reductionRegion().empty())
679     return emitOpError() << "expects non-empty reduction region";
680   Block &reductionEntryBlock = reductionRegion().front();
681   if (reductionEntryBlock.getNumArguments() != 2 ||
682       reductionEntryBlock.getArgumentTypes()[0] !=
683           reductionEntryBlock.getArgumentTypes()[1] ||
684       reductionEntryBlock.getArgumentTypes()[0] != type())
685     return emitOpError() << "expects reduction region with two arguments of "
686                             "the reduction type";
687   for (YieldOp yieldOp : reductionRegion().getOps<YieldOp>()) {
688     if (yieldOp.results().size() != 1 ||
689         yieldOp.results().getTypes()[0] != type())
690       return emitOpError() << "expects reduction region to yield a value "
691                               "of the reduction type";
692   }
693 
694   if (atomicReductionRegion().empty())
695     return success();
696 
697   Block &atomicReductionEntryBlock = atomicReductionRegion().front();
698   if (atomicReductionEntryBlock.getNumArguments() != 2 ||
699       atomicReductionEntryBlock.getArgumentTypes()[0] !=
700           atomicReductionEntryBlock.getArgumentTypes()[1])
701     return emitOpError() << "expects atomic reduction region with two "
702                             "arguments of the same type";
703   auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
704                      .dyn_cast<PointerLikeType>();
705   if (!ptrType || ptrType.getElementType() != type())
706     return emitOpError() << "expects atomic reduction region arguments to "
707                             "be accumulators containing the reduction type";
708   return success();
709 }
710 
711 LogicalResult ReductionOp::verify() {
712   // TODO: generalize this to an op interface when there is more than one op
713   // that supports reductions.
714   auto container = (*this)->getParentOfType<WsLoopOp>();
715   for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
716     if (container.reduction_vars()[i] == accumulator())
717       return success();
718 
719   return emitOpError() << "the accumulator is not used by the parent";
720 }
721 
722 //===----------------------------------------------------------------------===//
723 // WsLoopOp
724 //===----------------------------------------------------------------------===//
725 
726 void WsLoopOp::build(OpBuilder &builder, OperationState &state,
727                      ValueRange lowerBound, ValueRange upperBound,
728                      ValueRange step, ArrayRef<NamedAttribute> attributes) {
729   build(builder, state, lowerBound, upperBound, step,
730         /*linear_vars=*/ValueRange(),
731         /*linear_step_vars=*/ValueRange(), /*reduction_vars=*/ValueRange(),
732         /*reductions=*/nullptr, /*schedule_val=*/nullptr,
733         /*schedule_chunk_var=*/nullptr, /*schedule_modifier=*/nullptr,
734         /*simd_modifier=*/false, /*collapse_val=*/nullptr, /*nowait=*/false,
735         /*ordered_val=*/nullptr, /*order_val=*/nullptr, /*inclusive=*/false);
736   state.addAttributes(attributes);
737 }
738 
739 LogicalResult WsLoopOp::verify() {
740   return verifyReductionVarList(*this, reductions(), reduction_vars());
741 }
742 
743 //===----------------------------------------------------------------------===//
744 // Verifier for critical construct (2.17.1)
745 //===----------------------------------------------------------------------===//
746 
747 LogicalResult CriticalDeclareOp::verify() {
748   return verifySynchronizationHint(*this, hint_val());
749 }
750 
751 LogicalResult
752 CriticalOp::verifySymbolUses(SymbolTableCollection &symbol_table) {
753   if (nameAttr()) {
754     SymbolRefAttr symbolRef = nameAttr();
755     auto decl = symbol_table.lookupNearestSymbolFrom<CriticalDeclareOp>(
756         *this, symbolRef);
757     if (!decl) {
758       return emitOpError() << "expected symbol reference " << symbolRef
759                            << " to point to a critical declaration";
760     }
761   }
762 
763   return success();
764 }
765 
766 //===----------------------------------------------------------------------===//
767 // Verifier for ordered construct
768 //===----------------------------------------------------------------------===//
769 
770 LogicalResult OrderedOp::verify() {
771   auto container = (*this)->getParentOfType<WsLoopOp>();
772   if (!container || !container.ordered_valAttr() ||
773       container.ordered_valAttr().getInt() == 0)
774     return emitOpError() << "ordered depend directive must be closely "
775                          << "nested inside a worksharing-loop with ordered "
776                          << "clause with parameter present";
777 
778   if (container.ordered_valAttr().getInt() !=
779       (int64_t)num_loops_val().getValue())
780     return emitOpError() << "number of variables in depend clause does not "
781                          << "match number of iteration variables in the "
782                          << "doacross loop";
783 
784   return success();
785 }
786 
787 LogicalResult OrderedRegionOp::verify() {
788   // TODO: The code generation for ordered simd directive is not supported yet.
789   if (simd())
790     return failure();
791 
792   if (auto container = (*this)->getParentOfType<WsLoopOp>()) {
793     if (!container.ordered_valAttr() ||
794         container.ordered_valAttr().getInt() != 0)
795       return emitOpError() << "ordered region must be closely nested inside "
796                            << "a worksharing-loop region with an ordered "
797                            << "clause without parameter present";
798   }
799 
800   return success();
801 }
802 
803 //===----------------------------------------------------------------------===//
804 // Verifier for AtomicReadOp
805 //===----------------------------------------------------------------------===//
806 
807 LogicalResult AtomicReadOp::verify() {
808   if (auto mo = memory_order_val()) {
809     if (*mo == ClauseMemoryOrderKind::Acq_rel ||
810         *mo == ClauseMemoryOrderKind::Release) {
811       return emitError(
812           "memory-order must not be acq_rel or release for atomic reads");
813     }
814   }
815   if (x() == v())
816     return emitError(
817         "read and write must not be to the same location for atomic reads");
818   return verifySynchronizationHint(*this, hint_val());
819 }
820 
821 //===----------------------------------------------------------------------===//
822 // Verifier for AtomicWriteOp
823 //===----------------------------------------------------------------------===//
824 
825 LogicalResult AtomicWriteOp::verify() {
826   if (auto mo = memory_order_val()) {
827     if (*mo == ClauseMemoryOrderKind::Acq_rel ||
828         *mo == ClauseMemoryOrderKind::Acquire) {
829       return emitError(
830           "memory-order must not be acq_rel or acquire for atomic writes");
831     }
832   }
833   return verifySynchronizationHint(*this, hint_val());
834 }
835 
836 //===----------------------------------------------------------------------===//
837 // Verifier for AtomicUpdateOp
838 //===----------------------------------------------------------------------===//
839 
840 LogicalResult AtomicUpdateOp::verify() {
841   if (auto mo = memory_order_val()) {
842     if (*mo == ClauseMemoryOrderKind::Acq_rel ||
843         *mo == ClauseMemoryOrderKind::Acquire) {
844       return emitError(
845           "memory-order must not be acq_rel or acquire for atomic updates");
846     }
847   }
848 
849   if (x().getType().cast<PointerLikeType>().getElementType() !=
850       region().getArgument(0).getType()) {
851     return emitError("the type of the operand must be a pointer type whose "
852                      "element type is the same as that of the region argument");
853   }
854 
855   return success();
856 }
857 
858 LogicalResult AtomicUpdateOp::verifyRegions() {
859   if (region().getNumArguments() != 1)
860     return emitError("the region must accept exactly one argument");
861 
862   if (region().front().getOperations().size() < 2)
863     return emitError() << "the update region must have at least two operations "
864                           "(binop and terminator)";
865 
866   YieldOp yieldOp = *region().getOps<YieldOp>().begin();
867 
868   if (yieldOp.results().size() != 1)
869     return emitError("only updated value must be returned");
870   if (yieldOp.results().front().getType() != region().getArgument(0).getType())
871     return emitError("input and yielded value must have the same type");
872   return success();
873 }
874 
875 //===----------------------------------------------------------------------===//
876 // Verifier for AtomicCaptureOp
877 //===----------------------------------------------------------------------===//
878 
879 Operation *AtomicCaptureOp::getFirstOp() {
880   return &getRegion().front().getOperations().front();
881 }
882 
883 Operation *AtomicCaptureOp::getSecondOp() {
884   auto &ops = getRegion().front().getOperations();
885   return ops.getNextNode(ops.front());
886 }
887 
888 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
889   if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
890     return op;
891   return dyn_cast<AtomicReadOp>(getSecondOp());
892 }
893 
894 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
895   if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
896     return op;
897   return dyn_cast<AtomicWriteOp>(getSecondOp());
898 }
899 
900 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
901   if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
902     return op;
903   return dyn_cast<AtomicUpdateOp>(getSecondOp());
904 }
905 
906 LogicalResult AtomicCaptureOp::verifyRegions() {
907   Block::OpListType &ops = region().front().getOperations();
908   if (ops.size() != 3)
909     return emitError()
910            << "expected three operations in omp.atomic.capture region (one "
911               "terminator, and two atomic ops)";
912   auto &firstOp = ops.front();
913   auto &secondOp = *ops.getNextNode(firstOp);
914   auto firstReadStmt = dyn_cast<AtomicReadOp>(firstOp);
915   auto firstUpdateStmt = dyn_cast<AtomicUpdateOp>(firstOp);
916   auto secondReadStmt = dyn_cast<AtomicReadOp>(secondOp);
917   auto secondUpdateStmt = dyn_cast<AtomicUpdateOp>(secondOp);
918   auto secondWriteStmt = dyn_cast<AtomicWriteOp>(secondOp);
919 
920   if (!((firstUpdateStmt && secondReadStmt) ||
921         (firstReadStmt && secondUpdateStmt) ||
922         (firstReadStmt && secondWriteStmt)))
923     return ops.front().emitError()
924            << "invalid sequence of operations in the capture region";
925   if (firstUpdateStmt && secondReadStmt &&
926       firstUpdateStmt.x() != secondReadStmt.x())
927     return firstUpdateStmt.emitError()
928            << "updated variable in omp.atomic.update must be captured in "
929               "second operation";
930   if (firstReadStmt && secondUpdateStmt &&
931       firstReadStmt.x() != secondUpdateStmt.x())
932     return firstReadStmt.emitError()
933            << "captured variable in omp.atomic.read must be updated in second "
934               "operation";
935   if (firstReadStmt && secondWriteStmt &&
936       firstReadStmt.x() != secondWriteStmt.address())
937     return firstReadStmt.emitError()
938            << "captured variable in omp.atomic.read must be updated in "
939               "second operation";
940   return success();
941 }
942 
943 #define GET_ATTRDEF_CLASSES
944 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
945 
946 #define GET_OP_CLASSES
947 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
948