1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/OperationSupport.h"
19 
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/SmallString.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/ADT/StringSwitch.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include <cstddef>
27 
28 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
29 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
30 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
31 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
32 
33 using namespace mlir;
34 using namespace mlir::omp;
35 
36 namespace {
37 /// Model for pointer-like types that already provide a `getElementType` method.
38 template <typename T>
39 struct PointerLikeModel
40     : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
41   Type getElementType(Type pointer) const {
42     return pointer.cast<T>().getElementType();
43   }
44 };
45 } // namespace
46 
47 void OpenMPDialect::initialize() {
48   addOperations<
49 #define GET_OP_LIST
50 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
51       >();
52   addAttributes<
53 #define GET_ATTRDEF_LIST
54 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
55       >();
56 
57   LLVM::LLVMPointerType::attachInterface<
58       PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
59   MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
60 }
61 
62 //===----------------------------------------------------------------------===//
63 // Parser and printer for Allocate Clause
64 //===----------------------------------------------------------------------===//
65 
66 /// Parse an allocate clause with allocators and a list of operands with types.
67 ///
68 /// allocate-operand-list :: = allocate-operand |
69 ///                            allocator-operand `,` allocate-operand-list
70 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
71 /// ssa-id-and-type ::= ssa-id `:` type
72 static ParseResult parseAllocateAndAllocator(
73     OpAsmParser &parser,
74     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operandsAllocate,
75     SmallVectorImpl<Type> &typesAllocate,
76     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operandsAllocator,
77     SmallVectorImpl<Type> &typesAllocator) {
78 
79   return parser.parseCommaSeparatedList([&]() -> ParseResult {
80     OpAsmParser::UnresolvedOperand operand;
81     Type type;
82     if (parser.parseOperand(operand) || parser.parseColonType(type))
83       return failure();
84     operandsAllocator.push_back(operand);
85     typesAllocator.push_back(type);
86     if (parser.parseArrow())
87       return failure();
88     if (parser.parseOperand(operand) || parser.parseColonType(type))
89       return failure();
90 
91     operandsAllocate.push_back(operand);
92     typesAllocate.push_back(type);
93     return success();
94   });
95 }
96 
97 /// Print allocate clause
98 static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op,
99                                       OperandRange varsAllocate,
100                                       TypeRange typesAllocate,
101                                       OperandRange varsAllocator,
102                                       TypeRange typesAllocator) {
103   for (unsigned i = 0; i < varsAllocate.size(); ++i) {
104     std::string separator = i == varsAllocate.size() - 1 ? "" : ", ";
105     p << varsAllocator[i] << " : " << typesAllocator[i] << " -> ";
106     p << varsAllocate[i] << " : " << typesAllocate[i] << separator;
107   }
108 }
109 
110 //===----------------------------------------------------------------------===//
111 // Parser and printer for a clause attribute (StringEnumAttr)
112 //===----------------------------------------------------------------------===//
113 
114 template <typename ClauseAttr>
115 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
116   using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
117   StringRef enumStr;
118   SMLoc loc = parser.getCurrentLocation();
119   if (parser.parseKeyword(&enumStr))
120     return failure();
121   if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
122     attr = ClauseAttr::get(parser.getContext(), *enumValue);
123     return success();
124   }
125   return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
126 }
127 
128 template <typename ClauseAttr>
129 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
130   p << stringifyEnum(attr.getValue());
131 }
132 
133 //===----------------------------------------------------------------------===//
134 // Parser and printer for Linear Clause
135 //===----------------------------------------------------------------------===//
136 
137 /// linear ::= `linear` `(` linear-list `)`
138 /// linear-list := linear-val | linear-val linear-list
139 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
140 static ParseResult
141 parseLinearClause(OpAsmParser &parser,
142                   SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
143                   SmallVectorImpl<Type> &types,
144                   SmallVectorImpl<OpAsmParser::UnresolvedOperand> &stepVars) {
145   do {
146     OpAsmParser::UnresolvedOperand var;
147     Type type;
148     OpAsmParser::UnresolvedOperand stepVar;
149     if (parser.parseOperand(var) || parser.parseEqual() ||
150         parser.parseOperand(stepVar) || parser.parseColonType(type))
151       return failure();
152 
153     vars.push_back(var);
154     types.push_back(type);
155     stepVars.push_back(stepVar);
156   } while (succeeded(parser.parseOptionalComma()));
157   return success();
158 }
159 
160 /// Print Linear Clause
161 static void printLinearClause(OpAsmPrinter &p, Operation *op,
162                               ValueRange linearVars, TypeRange linearVarTypes,
163                               ValueRange linearStepVars) {
164   size_t linearVarsSize = linearVars.size();
165   for (unsigned i = 0; i < linearVarsSize; ++i) {
166     std::string separator = i == linearVarsSize - 1 ? "" : ", ";
167     p << linearVars[i];
168     if (linearStepVars.size() > i)
169       p << " = " << linearStepVars[i];
170     p << " : " << linearVars[i].getType() << separator;
171   }
172 }
173 
174 //===----------------------------------------------------------------------===//
175 // Parser, printer and verifier for Schedule Clause
176 //===----------------------------------------------------------------------===//
177 
178 static ParseResult
179 verifyScheduleModifiers(OpAsmParser &parser,
180                         SmallVectorImpl<SmallString<12>> &modifiers) {
181   if (modifiers.size() > 2)
182     return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
183   for (const auto &mod : modifiers) {
184     // Translate the string. If it has no value, then it was not a valid
185     // modifier!
186     auto symbol = symbolizeScheduleModifier(mod);
187     if (!symbol.hasValue())
188       return parser.emitError(parser.getNameLoc())
189              << " unknown modifier type: " << mod;
190   }
191 
192   // If we have one modifier that is "simd", then stick a "none" modiifer in
193   // index 0.
194   if (modifiers.size() == 1) {
195     if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
196       modifiers.push_back(modifiers[0]);
197       modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
198     }
199   } else if (modifiers.size() == 2) {
200     // If there are two modifier:
201     // First modifier should not be simd, second one should be simd
202     if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
203         symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
204       return parser.emitError(parser.getNameLoc())
205              << " incorrect modifier order";
206   }
207   return success();
208 }
209 
210 /// schedule ::= `schedule` `(` sched-list `)`
211 /// sched-list ::= sched-val | sched-val sched-list |
212 ///                sched-val `,` sched-modifier
213 /// sched-val ::= sched-with-chunk | sched-wo-chunk
214 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
215 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
216 /// sched-wo-chunk ::=  `auto` | `runtime`
217 /// sched-modifier ::=  sched-mod-val | sched-mod-val `,` sched-mod-val
218 /// sched-mod-val ::=  `monotonic` | `nonmonotonic` | `simd` | `none`
219 static ParseResult parseScheduleClause(
220     OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
221     ScheduleModifierAttr &scheduleModifier, UnitAttr &simdModifier,
222     Optional<OpAsmParser::UnresolvedOperand> &chunkSize, Type &chunkType) {
223   StringRef keyword;
224   if (parser.parseKeyword(&keyword))
225     return failure();
226   llvm::Optional<mlir::omp::ClauseScheduleKind> schedule =
227       symbolizeClauseScheduleKind(keyword);
228   if (!schedule)
229     return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
230 
231   scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
232   switch (*schedule) {
233   case ClauseScheduleKind::Static:
234   case ClauseScheduleKind::Dynamic:
235   case ClauseScheduleKind::Guided:
236     if (succeeded(parser.parseOptionalEqual())) {
237       chunkSize = OpAsmParser::UnresolvedOperand{};
238       if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
239         return failure();
240     } else {
241       chunkSize = llvm::NoneType::None;
242     }
243     break;
244   case ClauseScheduleKind::Auto:
245   case ClauseScheduleKind::Runtime:
246     chunkSize = llvm::NoneType::None;
247   }
248 
249   // If there is a comma, we have one or more modifiers..
250   SmallVector<SmallString<12>> modifiers;
251   while (succeeded(parser.parseOptionalComma())) {
252     StringRef mod;
253     if (parser.parseKeyword(&mod))
254       return failure();
255     modifiers.push_back(mod);
256   }
257 
258   if (verifyScheduleModifiers(parser, modifiers))
259     return failure();
260 
261   if (!modifiers.empty()) {
262     SMLoc loc = parser.getCurrentLocation();
263     if (Optional<ScheduleModifier> mod =
264             symbolizeScheduleModifier(modifiers[0])) {
265       scheduleModifier = ScheduleModifierAttr::get(parser.getContext(), *mod);
266     } else {
267       return parser.emitError(loc, "invalid schedule modifier");
268     }
269     // Only SIMD attribute is allowed here!
270     if (modifiers.size() > 1) {
271       assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
272       simdModifier = UnitAttr::get(parser.getBuilder().getContext());
273     }
274   }
275 
276   return success();
277 }
278 
279 /// Print schedule clause
280 static void printScheduleClause(OpAsmPrinter &p, Operation *op,
281                                 ClauseScheduleKindAttr schedAttr,
282                                 ScheduleModifierAttr modifier, UnitAttr simd,
283                                 Value scheduleChunkVar,
284                                 Type scheduleChunkType) {
285   p << stringifyClauseScheduleKind(schedAttr.getValue());
286   if (scheduleChunkVar)
287     p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType();
288   if (modifier)
289     p << ", " << stringifyScheduleModifier(modifier.getValue());
290   if (simd)
291     p << ", simd";
292 }
293 
294 //===----------------------------------------------------------------------===//
295 // Parser, printer and verifier for ReductionVarList
296 //===----------------------------------------------------------------------===//
297 
298 /// reduction-entry-list ::= reduction-entry
299 ///                        | reduction-entry-list `,` reduction-entry
300 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
301 static ParseResult
302 parseReductionVarList(OpAsmParser &parser,
303                       SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
304                       SmallVectorImpl<Type> &types,
305                       ArrayAttr &redcuctionSymbols) {
306   SmallVector<SymbolRefAttr> reductionVec;
307   do {
308     if (parser.parseAttribute(reductionVec.emplace_back()) ||
309         parser.parseArrow() || parser.parseOperand(operands.emplace_back()) ||
310         parser.parseColonType(types.emplace_back()))
311       return failure();
312   } while (succeeded(parser.parseOptionalComma()));
313   SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
314   redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions);
315   return success();
316 }
317 
318 /// Print Reduction clause
319 static void printReductionVarList(OpAsmPrinter &p, Operation *op,
320                                   OperandRange reductionVars,
321                                   TypeRange reductionTypes,
322                                   Optional<ArrayAttr> reductions) {
323   for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
324     if (i != 0)
325       p << ", ";
326     p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
327       << reductionVars[i].getType();
328   }
329 }
330 
331 /// Verifies Reduction Clause
332 static LogicalResult verifyReductionVarList(Operation *op,
333                                             Optional<ArrayAttr> reductions,
334                                             OperandRange reductionVars) {
335   if (!reductionVars.empty()) {
336     if (!reductions || reductions->size() != reductionVars.size())
337       return op->emitOpError()
338              << "expected as many reduction symbol references "
339                 "as reduction variables";
340   } else {
341     if (reductions)
342       return op->emitOpError() << "unexpected reduction symbol references";
343     return success();
344   }
345 
346   // TODO: The followings should be done in
347   // SymbolUserOpInterface::verifySymbolUses.
348   DenseSet<Value> accumulators;
349   for (auto args : llvm::zip(reductionVars, *reductions)) {
350     Value accum = std::get<0>(args);
351 
352     if (!accumulators.insert(accum).second)
353       return op->emitOpError() << "accumulator variable used more than once";
354 
355     Type varType = accum.getType().cast<PointerLikeType>();
356     auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
357     auto decl =
358         SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
359     if (!decl)
360       return op->emitOpError() << "expected symbol reference " << symbolRef
361                                << " to point to a reduction declaration";
362 
363     if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
364       return op->emitOpError()
365              << "expected accumulator (" << varType
366              << ") to be the same type as reduction declaration ("
367              << decl.getAccumulatorType() << ")";
368   }
369 
370   return success();
371 }
372 
373 //===----------------------------------------------------------------------===//
374 // Parser, printer and verifier for Synchronization Hint (2.17.12)
375 //===----------------------------------------------------------------------===//
376 
377 /// Parses a Synchronization Hint clause. The value of hint is an integer
378 /// which is a combination of different hints from `omp_sync_hint_t`.
379 ///
380 /// hint-clause = `hint` `(` hint-value `)`
381 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
382                                             IntegerAttr &hintAttr) {
383   StringRef hintKeyword;
384   int64_t hint = 0;
385   if (succeeded(parser.parseOptionalKeyword("none"))) {
386     hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
387     return success();
388   }
389   do {
390     if (failed(parser.parseKeyword(&hintKeyword)))
391       return failure();
392     if (hintKeyword == "uncontended")
393       hint |= 1;
394     else if (hintKeyword == "contended")
395       hint |= 2;
396     else if (hintKeyword == "nonspeculative")
397       hint |= 4;
398     else if (hintKeyword == "speculative")
399       hint |= 8;
400     else
401       return parser.emitError(parser.getCurrentLocation())
402              << hintKeyword << " is not a valid hint";
403   } while (succeeded(parser.parseOptionalComma()));
404   hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
405   return success();
406 }
407 
408 /// Prints a Synchronization Hint clause
409 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
410                                      IntegerAttr hintAttr) {
411   int64_t hint = hintAttr.getInt();
412 
413   if (hint == 0) {
414     p << "none";
415     return;
416   }
417 
418   // Helper function to get n-th bit from the right end of `value`
419   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
420 
421   bool uncontended = bitn(hint, 0);
422   bool contended = bitn(hint, 1);
423   bool nonspeculative = bitn(hint, 2);
424   bool speculative = bitn(hint, 3);
425 
426   SmallVector<StringRef> hints;
427   if (uncontended)
428     hints.push_back("uncontended");
429   if (contended)
430     hints.push_back("contended");
431   if (nonspeculative)
432     hints.push_back("nonspeculative");
433   if (speculative)
434     hints.push_back("speculative");
435 
436   llvm::interleaveComma(hints, p);
437 }
438 
439 /// Verifies a synchronization hint clause
440 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
441 
442   // Helper function to get n-th bit from the right end of `value`
443   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
444 
445   bool uncontended = bitn(hint, 0);
446   bool contended = bitn(hint, 1);
447   bool nonspeculative = bitn(hint, 2);
448   bool speculative = bitn(hint, 3);
449 
450   if (uncontended && contended)
451     return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
452                                 "omp_sync_hint_contended cannot be combined";
453   if (nonspeculative && speculative)
454     return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
455                                 "omp_sync_hint_speculative cannot be combined.";
456   return success();
457 }
458 
459 //===----------------------------------------------------------------------===//
460 // ParallelOp
461 //===----------------------------------------------------------------------===//
462 
/// Convenience builder: delegates to the generated ParallelOp builder with
/// every clause operand and attribute left null or empty, then appends the
/// caller-provided `attributes` to `state`.
void ParallelOp::build(OpBuilder &builder, OperationState &state,
                       ArrayRef<NamedAttribute> attributes) {
  // All clauses unset; argument order follows the generated builder.
  ParallelOp::build(
      builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
      /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
      /*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
      /*proc_bind_val=*/nullptr);
  // Caller-supplied attributes are added after the delegated build.
  state.addAttributes(attributes);
}
472 
473 LogicalResult ParallelOp::verify() {
474   if (allocate_vars().size() != allocators_vars().size())
475     return emitError(
476         "expected equal sizes for allocate and allocator variables");
477   return verifyReductionVarList(*this, reductions(), reduction_vars());
478 }
479 
480 //===----------------------------------------------------------------------===//
481 // Verifier for SectionsOp
482 //===----------------------------------------------------------------------===//
483 
484 LogicalResult SectionsOp::verify() {
485   if (allocate_vars().size() != allocators_vars().size())
486     return emitError(
487         "expected equal sizes for allocate and allocator variables");
488 
489   return verifyReductionVarList(*this, reductions(), reduction_vars());
490 }
491 
492 LogicalResult SectionsOp::verifyRegions() {
493   for (auto &inst : *region().begin()) {
494     if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
495       return emitOpError()
496              << "expected omp.section op or terminator op inside region";
497     }
498   }
499 
500   return success();
501 }
502 
503 LogicalResult SingleOp::verify() {
504   // Check for allocate clause restrictions
505   if (allocate_vars().size() != allocators_vars().size())
506     return emitError(
507         "expected equal sizes for allocate and allocator variables");
508 
509   return success();
510 }
511 
512 //===----------------------------------------------------------------------===//
513 // WsLoopOp
514 //===----------------------------------------------------------------------===//
515 
516 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
517 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
518 /// steps := `step` `(`ssa-id-list`)`
519 ParseResult
520 parseWsLoopControl(OpAsmParser &parser, Region &region,
521                    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &lowerBound,
522                    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &upperBound,
523                    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &steps,
524                    SmallVectorImpl<Type> &loopVarTypes, UnitAttr &inclusive) {
525   // Parse an opening `(` followed by induction variables followed by `)`
526   SmallVector<OpAsmParser::UnresolvedOperand> ivs;
527   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
528                                      OpAsmParser::Delimiter::Paren))
529     return failure();
530 
531   size_t numIVs = ivs.size();
532   Type loopVarType;
533   if (parser.parseColonType(loopVarType))
534     return failure();
535 
536   // Parse loop bounds.
537   if (parser.parseEqual() ||
538       parser.parseOperandList(lowerBound, numIVs,
539                               OpAsmParser::Delimiter::Paren))
540     return failure();
541   if (parser.parseKeyword("to") ||
542       parser.parseOperandList(upperBound, numIVs,
543                               OpAsmParser::Delimiter::Paren))
544     return failure();
545 
546   if (succeeded(parser.parseOptionalKeyword("inclusive"))) {
547     inclusive = UnitAttr::get(parser.getBuilder().getContext());
548   }
549 
550   // Parse step values.
551   if (parser.parseKeyword("step") ||
552       parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren))
553     return failure();
554 
555   // Now parse the body.
556   loopVarTypes = SmallVector<Type>(numIVs, loopVarType);
557   SmallVector<OpAsmParser::UnresolvedOperand> blockArgs(ivs);
558   if (parser.parseRegion(region, blockArgs, loopVarTypes))
559     return failure();
560   return success();
561 }
562 
563 void printWsLoopControl(OpAsmPrinter &p, Operation *op, Region &region,
564                         ValueRange lowerBound, ValueRange upperBound,
565                         ValueRange steps, TypeRange loopVarTypes,
566                         UnitAttr inclusive) {
567   auto args = region.front().getArguments();
568   p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound
569     << ") to (" << upperBound << ") ";
570   if (inclusive)
571     p << "inclusive ";
572   p << "step (" << steps << ") ";
573   p.printRegion(region, /*printEntryBlockArgs=*/false);
574 }
575 
576 //===----------------------------------------------------------------------===//
577 // SimdLoopOp
578 //===----------------------------------------------------------------------===//
579 /// Parses an OpenMP Simd construct [2.9.3.1]
580 ///
581 /// simdloop ::= `omp.simdloop` loop-control clause-list
582 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
583 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps
584 /// steps := `step` `(`ssa-id-list`)`
585 /// clause-list ::= clause clause-list | empty
586 /// clause ::= TODO
587 ParseResult SimdLoopOp::parse(OpAsmParser &parser, OperationState &result) {
588   // Parse an opening `(` followed by induction variables followed by `)`
589   SmallVector<OpAsmParser::UnresolvedOperand> ivs;
590   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
591                                      OpAsmParser::Delimiter::Paren))
592     return failure();
593   int numIVs = static_cast<int>(ivs.size());
594   Type loopVarType;
595   if (parser.parseColonType(loopVarType))
596     return failure();
597   // Parse loop bounds.
598   SmallVector<OpAsmParser::UnresolvedOperand> lower;
599   if (parser.parseEqual() ||
600       parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
601       parser.resolveOperands(lower, loopVarType, result.operands))
602     return failure();
603   SmallVector<OpAsmParser::UnresolvedOperand> upper;
604   if (parser.parseKeyword("to") ||
605       parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
606       parser.resolveOperands(upper, loopVarType, result.operands))
607     return failure();
608 
609   // Parse step values.
610   SmallVector<OpAsmParser::UnresolvedOperand> steps;
611   if (parser.parseKeyword("step") ||
612       parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
613       parser.resolveOperands(steps, loopVarType, result.operands))
614     return failure();
615 
616   SmallVector<int> segments{numIVs, numIVs, numIVs};
617   // TODO: Add parseClauses() when we support clauses
618   result.addAttribute("operand_segment_sizes",
619                       parser.getBuilder().getI32VectorAttr(segments));
620 
621   // Now parse the body.
622   Region *body = result.addRegion();
623   SmallVector<Type> ivTypes(numIVs, loopVarType);
624   SmallVector<OpAsmParser::UnresolvedOperand> blockArgs(ivs);
625   if (parser.parseRegion(*body, blockArgs, ivTypes))
626     return failure();
627   return success();
628 }
629 
630 void SimdLoopOp::print(OpAsmPrinter &p) {
631   auto args = getRegion().front().getArguments();
632   p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound()
633     << ") to (" << upperBound() << ") ";
634   p << "step (" << step() << ") ";
635 
636   p.printRegion(region(), /*printEntryBlockArgs=*/false);
637 }
638 
639 //===----------------------------------------------------------------------===//
640 // Verifier for Simd construct [2.9.3.1]
641 //===----------------------------------------------------------------------===//
642 
643 LogicalResult SimdLoopOp::verify() {
644   if (this->lowerBound().empty()) {
645     return emitOpError() << "empty lowerbound for simd loop operation";
646   }
647   return success();
648 }
649 
650 //===----------------------------------------------------------------------===//
651 // ReductionOp
652 //===----------------------------------------------------------------------===//
653 
654 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
655                                               Region &region) {
656   if (parser.parseOptionalKeyword("atomic"))
657     return success();
658   return parser.parseRegion(region);
659 }
660 
661 static void printAtomicReductionRegion(OpAsmPrinter &printer,
662                                        ReductionDeclareOp op, Region &region) {
663   if (region.empty())
664     return;
665   printer << "atomic ";
666   printer.printRegion(region);
667 }
668 
669 LogicalResult ReductionDeclareOp::verifyRegions() {
670   if (initializerRegion().empty())
671     return emitOpError() << "expects non-empty initializer region";
672   Block &initializerEntryBlock = initializerRegion().front();
673   if (initializerEntryBlock.getNumArguments() != 1 ||
674       initializerEntryBlock.getArgument(0).getType() != type()) {
675     return emitOpError() << "expects initializer region with one argument "
676                             "of the reduction type";
677   }
678 
679   for (YieldOp yieldOp : initializerRegion().getOps<YieldOp>()) {
680     if (yieldOp.results().size() != 1 ||
681         yieldOp.results().getTypes()[0] != type())
682       return emitOpError() << "expects initializer region to yield a value "
683                               "of the reduction type";
684   }
685 
686   if (reductionRegion().empty())
687     return emitOpError() << "expects non-empty reduction region";
688   Block &reductionEntryBlock = reductionRegion().front();
689   if (reductionEntryBlock.getNumArguments() != 2 ||
690       reductionEntryBlock.getArgumentTypes()[0] !=
691           reductionEntryBlock.getArgumentTypes()[1] ||
692       reductionEntryBlock.getArgumentTypes()[0] != type())
693     return emitOpError() << "expects reduction region with two arguments of "
694                             "the reduction type";
695   for (YieldOp yieldOp : reductionRegion().getOps<YieldOp>()) {
696     if (yieldOp.results().size() != 1 ||
697         yieldOp.results().getTypes()[0] != type())
698       return emitOpError() << "expects reduction region to yield a value "
699                               "of the reduction type";
700   }
701 
702   if (atomicReductionRegion().empty())
703     return success();
704 
705   Block &atomicReductionEntryBlock = atomicReductionRegion().front();
706   if (atomicReductionEntryBlock.getNumArguments() != 2 ||
707       atomicReductionEntryBlock.getArgumentTypes()[0] !=
708           atomicReductionEntryBlock.getArgumentTypes()[1])
709     return emitOpError() << "expects atomic reduction region with two "
710                             "arguments of the same type";
711   auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
712                      .dyn_cast<PointerLikeType>();
713   if (!ptrType || ptrType.getElementType() != type())
714     return emitOpError() << "expects atomic reduction region arguments to "
715                             "be accumulators containing the reduction type";
716   return success();
717 }
718 
719 LogicalResult ReductionOp::verify() {
720   auto *op = (*this)->getParentWithTrait<ReductionClauseInterface::Trait>();
721   if (!op)
722     return emitOpError() << "must be used within an operation supporting "
723                             "reduction clause interface";
724   while (op) {
725     for (const auto &var :
726          cast<ReductionClauseInterface>(op).getReductionVars())
727       if (var == accumulator())
728         return success();
729     op = op->getParentWithTrait<ReductionClauseInterface::Trait>();
730   }
731   return emitOpError() << "the accumulator is not used by the parent";
732 }
733 
734 //===----------------------------------------------------------------------===//
735 // TaskOp
736 //===----------------------------------------------------------------------===//
737 LogicalResult TaskOp::verify() {
738   return verifyReductionVarList(*this, in_reductions(), in_reduction_vars());
739 }
740 
741 //===----------------------------------------------------------------------===//
742 // WsLoopOp
743 //===----------------------------------------------------------------------===//
744 
/// Convenience builder: creates a worksharing loop with the given bounds and
/// steps, every other clause operand/attribute null, empty, or false, and the
/// caller-provided `attributes` appended to `state`.
void WsLoopOp::build(OpBuilder &builder, OperationState &state,
                     ValueRange lowerBound, ValueRange upperBound,
                     ValueRange step, ArrayRef<NamedAttribute> attributes) {
  // Delegate to the generated builder; argument order follows that builder.
  build(builder, state, lowerBound, upperBound, step,
        /*linear_vars=*/ValueRange(),
        /*linear_step_vars=*/ValueRange(), /*reduction_vars=*/ValueRange(),
        /*reductions=*/nullptr, /*schedule_val=*/nullptr,
        /*schedule_chunk_var=*/nullptr, /*schedule_modifier=*/nullptr,
        /*simd_modifier=*/false, /*collapse_val=*/nullptr, /*nowait=*/false,
        /*ordered_val=*/nullptr, /*order_val=*/nullptr, /*inclusive=*/false);
  // Caller-supplied attributes are added after the delegated build.
  state.addAttributes(attributes);
}
757 
758 LogicalResult WsLoopOp::verify() {
759   return verifyReductionVarList(*this, reductions(), reduction_vars());
760 }
761 
762 //===----------------------------------------------------------------------===//
763 // Verifier for critical construct (2.17.1)
764 //===----------------------------------------------------------------------===//
765 
766 LogicalResult CriticalDeclareOp::verify() {
767   return verifySynchronizationHint(*this, hint_val());
768 }
769 
770 LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
771   if (nameAttr()) {
772     SymbolRefAttr symbolRef = nameAttr();
773     auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
774         *this, symbolRef);
775     if (!decl) {
776       return emitOpError() << "expected symbol reference " << symbolRef
777                            << " to point to a critical declaration";
778     }
779   }
780 
781   return success();
782 }
783 
784 //===----------------------------------------------------------------------===//
785 // Verifier for ordered construct
786 //===----------------------------------------------------------------------===//
787 
788 LogicalResult OrderedOp::verify() {
789   auto container = (*this)->getParentOfType<WsLoopOp>();
790   if (!container || !container.ordered_valAttr() ||
791       container.ordered_valAttr().getInt() == 0)
792     return emitOpError() << "ordered depend directive must be closely "
793                          << "nested inside a worksharing-loop with ordered "
794                          << "clause with parameter present";
795 
796   if (container.ordered_valAttr().getInt() !=
797       (int64_t)num_loops_val().getValue())
798     return emitOpError() << "number of variables in depend clause does not "
799                          << "match number of iteration variables in the "
800                          << "doacross loop";
801 
802   return success();
803 }
804 
805 LogicalResult OrderedRegionOp::verify() {
806   // TODO: The code generation for ordered simd directive is not supported yet.
807   if (simd())
808     return failure();
809 
810   if (auto container = (*this)->getParentOfType<WsLoopOp>()) {
811     if (!container.ordered_valAttr() ||
812         container.ordered_valAttr().getInt() != 0)
813       return emitOpError() << "ordered region must be closely nested inside "
814                            << "a worksharing-loop region with an ordered "
815                            << "clause without parameter present";
816   }
817 
818   return success();
819 }
820 
821 //===----------------------------------------------------------------------===//
822 // Verifier for AtomicReadOp
823 //===----------------------------------------------------------------------===//
824 
825 LogicalResult AtomicReadOp::verify() {
826   if (auto mo = memory_order_val()) {
827     if (*mo == ClauseMemoryOrderKind::Acq_rel ||
828         *mo == ClauseMemoryOrderKind::Release) {
829       return emitError(
830           "memory-order must not be acq_rel or release for atomic reads");
831     }
832   }
833   if (x() == v())
834     return emitError(
835         "read and write must not be to the same location for atomic reads");
836   return verifySynchronizationHint(*this, hint_val());
837 }
838 
839 //===----------------------------------------------------------------------===//
840 // Verifier for AtomicWriteOp
841 //===----------------------------------------------------------------------===//
842 
843 LogicalResult AtomicWriteOp::verify() {
844   if (auto mo = memory_order_val()) {
845     if (*mo == ClauseMemoryOrderKind::Acq_rel ||
846         *mo == ClauseMemoryOrderKind::Acquire) {
847       return emitError(
848           "memory-order must not be acq_rel or acquire for atomic writes");
849     }
850   }
851   return verifySynchronizationHint(*this, hint_val());
852 }
853 
854 //===----------------------------------------------------------------------===//
855 // Verifier for AtomicUpdateOp
856 //===----------------------------------------------------------------------===//
857 
858 LogicalResult AtomicUpdateOp::verify() {
859   if (auto mo = memory_order_val()) {
860     if (*mo == ClauseMemoryOrderKind::Acq_rel ||
861         *mo == ClauseMemoryOrderKind::Acquire) {
862       return emitError(
863           "memory-order must not be acq_rel or acquire for atomic updates");
864     }
865   }
866 
867   if (x().getType().cast<PointerLikeType>().getElementType() !=
868       region().getArgument(0).getType()) {
869     return emitError("the type of the operand must be a pointer type whose "
870                      "element type is the same as that of the region argument");
871   }
872 
873   return verifySynchronizationHint(*this, hint_val());
874 }
875 
876 LogicalResult AtomicUpdateOp::verifyRegions() {
877   if (region().getNumArguments() != 1)
878     return emitError("the region must accept exactly one argument");
879 
880   if (region().front().getOperations().size() < 2)
881     return emitError() << "the update region must have at least two operations "
882                           "(binop and terminator)";
883 
884   YieldOp yieldOp = *region().getOps<YieldOp>().begin();
885 
886   if (yieldOp.results().size() != 1)
887     return emitError("only updated value must be returned");
888   if (yieldOp.results().front().getType() != region().getArgument(0).getType())
889     return emitError("input and yielded value must have the same type");
890   return success();
891 }
892 
893 //===----------------------------------------------------------------------===//
894 // Verifier for AtomicCaptureOp
895 //===----------------------------------------------------------------------===//
896 
897 Operation *AtomicCaptureOp::getFirstOp() {
898   return &getRegion().front().getOperations().front();
899 }
900 
901 Operation *AtomicCaptureOp::getSecondOp() {
902   auto &ops = getRegion().front().getOperations();
903   return ops.getNextNode(ops.front());
904 }
905 
906 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
907   if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
908     return op;
909   return dyn_cast<AtomicReadOp>(getSecondOp());
910 }
911 
912 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
913   if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
914     return op;
915   return dyn_cast<AtomicWriteOp>(getSecondOp());
916 }
917 
918 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
919   if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
920     return op;
921   return dyn_cast<AtomicUpdateOp>(getSecondOp());
922 }
923 
924 LogicalResult AtomicCaptureOp::verify() {
925   return verifySynchronizationHint(*this, hint_val());
926 }
927 
928 LogicalResult AtomicCaptureOp::verifyRegions() {
929   Block::OpListType &ops = region().front().getOperations();
930   if (ops.size() != 3)
931     return emitError()
932            << "expected three operations in omp.atomic.capture region (one "
933               "terminator, and two atomic ops)";
934   auto &firstOp = ops.front();
935   auto &secondOp = *ops.getNextNode(firstOp);
936   auto firstReadStmt = dyn_cast<AtomicReadOp>(firstOp);
937   auto firstUpdateStmt = dyn_cast<AtomicUpdateOp>(firstOp);
938   auto secondReadStmt = dyn_cast<AtomicReadOp>(secondOp);
939   auto secondUpdateStmt = dyn_cast<AtomicUpdateOp>(secondOp);
940   auto secondWriteStmt = dyn_cast<AtomicWriteOp>(secondOp);
941 
942   if (!((firstUpdateStmt && secondReadStmt) ||
943         (firstReadStmt && secondUpdateStmt) ||
944         (firstReadStmt && secondWriteStmt)))
945     return ops.front().emitError()
946            << "invalid sequence of operations in the capture region";
947   if (firstUpdateStmt && secondReadStmt &&
948       firstUpdateStmt.x() != secondReadStmt.x())
949     return firstUpdateStmt.emitError()
950            << "updated variable in omp.atomic.update must be captured in "
951               "second operation";
952   if (firstReadStmt && secondUpdateStmt &&
953       firstReadStmt.x() != secondUpdateStmt.x())
954     return firstReadStmt.emitError()
955            << "captured variable in omp.atomic.read must be updated in second "
956               "operation";
957   if (firstReadStmt && secondWriteStmt &&
958       firstReadStmt.x() != secondWriteStmt.address())
959     return firstReadStmt.emitError()
960            << "captured variable in omp.atomic.read must be updated in "
961               "second operation";
962 
963   if (getFirstOp()->getAttr("hint_val") || getSecondOp()->getAttr("hint_val"))
964     return emitOpError(
965         "operations inside capture region must not have hint clause");
966   return success();
967 }
968 
969 #define GET_ATTRDEF_CLASSES
970 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
971 
972 #define GET_OP_CLASSES
973 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
974