1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/OperationSupport.h"
19 
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/SmallString.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/ADT/StringSwitch.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include <cstddef>
27 
28 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
29 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
30 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
31 
32 using namespace mlir;
33 using namespace mlir::omp;
34 
35 namespace {
36 /// Model for pointer-like types that already provide a `getElementType` method.
37 template <typename T>
38 struct PointerLikeModel
39     : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
40   Type getElementType(Type pointer) const {
41     return pointer.cast<T>().getElementType();
42   }
43 };
44 } // namespace
45 
/// Registers the dialect's operations and attributes (lists generated from
/// ODS via the GET_OP_LIST / GET_ATTRDEF_LIST includes) and attaches the
/// PointerLikeType interface to LLVM pointer types and memrefs.
void OpenMPDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
      >();

  // Both types already expose `getElementType`, so PointerLikeModel can
  // forward to it; this makes them usable wherever PointerLikeType is checked
  // (e.g. reduction accumulators).
  LLVM::LLVMPointerType::attachInterface<
      PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
  MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
}
60 
61 //===----------------------------------------------------------------------===//
62 // ParallelOp
63 //===----------------------------------------------------------------------===//
64 
65 void ParallelOp::build(OpBuilder &builder, OperationState &state,
66                        ArrayRef<NamedAttribute> attributes) {
67   ParallelOp::build(
68       builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
69       /*default_val=*/nullptr, /*private_vars=*/ValueRange(),
70       /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(),
71       /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(),
72       /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr);
73   state.addAttributes(attributes);
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // Parser and printer for Operand and type list
78 //===----------------------------------------------------------------------===//
79 
80 /// Parse a list of operands with types.
81 ///
82 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
83 /// ssa-id-and-type-list ::= ssa-id-and-type |
84 ///                          ssa-id-and-type `,` ssa-id-and-type-list
85 /// ssa-id-and-type ::= ssa-id `:` type
86 static ParseResult
87 parseOperandAndTypeList(OpAsmParser &parser,
88                         SmallVectorImpl<OpAsmParser::OperandType> &operands,
89                         SmallVectorImpl<Type> &types) {
90   return parser.parseCommaSeparatedList(
91       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
92         OpAsmParser::OperandType operand;
93         Type type;
94         if (parser.parseOperand(operand) || parser.parseColonType(type))
95           return failure();
96         operands.push_back(operand);
97         types.push_back(type);
98         return success();
99       });
100 }
101 
102 /// Print an operand and type list with parentheses
103 static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) {
104   p << "(";
105   llvm::interleaveComma(
106       operands, p, [&](const Value &v) { p << v << " : " << v.getType(); });
107   p << ") ";
108 }
109 
110 /// Print data variables corresponding to a data-sharing clause `name`
111 static void printDataVars(OpAsmPrinter &p, OperandRange operands,
112                           StringRef name) {
113   if (!operands.empty()) {
114     p << name;
115     printOperandAndTypeList(p, operands);
116   }
117 }
118 
119 //===----------------------------------------------------------------------===//
120 // Parser and printer for Allocate Clause
121 //===----------------------------------------------------------------------===//
122 
123 /// Parse an allocate clause with allocators and a list of operands with types.
124 ///
125 /// allocate ::= `allocate` `(` allocate-operand-list `)`
126 /// allocate-operand-list :: = allocate-operand |
127 ///                            allocator-operand `,` allocate-operand-list
128 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
129 /// ssa-id-and-type ::= ssa-id `:` type
130 static ParseResult parseAllocateAndAllocator(
131     OpAsmParser &parser,
132     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
133     SmallVectorImpl<Type> &typesAllocate,
134     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
135     SmallVectorImpl<Type> &typesAllocator) {
136 
137   return parser.parseCommaSeparatedList(
138       OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
139         OpAsmParser::OperandType operand;
140         Type type;
141         if (parser.parseOperand(operand) || parser.parseColonType(type))
142           return failure();
143         operandsAllocator.push_back(operand);
144         typesAllocator.push_back(type);
145         if (parser.parseArrow())
146           return failure();
147         if (parser.parseOperand(operand) || parser.parseColonType(type))
148           return failure();
149 
150         operandsAllocate.push_back(operand);
151         typesAllocate.push_back(type);
152         return success();
153       });
154 }
155 
156 /// Print allocate clause
157 static void printAllocateAndAllocator(OpAsmPrinter &p,
158                                       OperandRange varsAllocate,
159                                       OperandRange varsAllocator) {
160   p << "allocate(";
161   for (unsigned i = 0; i < varsAllocate.size(); ++i) {
162     std::string separator = i == varsAllocate.size() - 1 ? ") " : ", ";
163     p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
164     p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
165   }
166 }
167 
168 LogicalResult ParallelOp::verify() {
169   if (allocate_vars().size() != allocators_vars().size())
170     return emitError(
171         "expected equal sizes for allocate and allocator variables");
172   return success();
173 }
174 
175 void ParallelOp::print(OpAsmPrinter &p) {
176   p << " ";
177   if (auto ifCond = if_expr_var())
178     p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
179 
180   if (auto threads = num_threads_var())
181     p << "num_threads(" << threads << " : " << threads.getType() << ") ";
182 
183   printDataVars(p, private_vars(), "private");
184   printDataVars(p, firstprivate_vars(), "firstprivate");
185   printDataVars(p, shared_vars(), "shared");
186   printDataVars(p, copyin_vars(), "copyin");
187 
188   if (!allocate_vars().empty())
189     printAllocateAndAllocator(p, allocate_vars(), allocators_vars());
190 
191   if (auto def = default_val())
192     p << "default(" << stringifyClauseDefault(*def).drop_front(3) << ") ";
193 
194   if (auto bind = proc_bind_val())
195     p << "proc_bind(" << stringifyClauseProcBindKind(*bind) << ") ";
196 
197   p << ' ';
198   p.printRegion(getRegion());
199 }
200 
201 void TargetOp::print(OpAsmPrinter &p) {
202   p << " ";
203   if (auto ifCond = if_expr())
204     p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
205 
206   if (auto device = this->device())
207     p << "device(" << device << " : " << device.getType() << ") ";
208 
209   if (auto threads = thread_limit())
210     p << "thread_limit(" << threads << " : " << threads.getType() << ") ";
211 
212   if (nowait())
213     p << "nowait ";
214 
215   p.printRegion(getRegion());
216 }
217 
218 //===----------------------------------------------------------------------===//
219 // Parser and printer for Linear Clause
220 //===----------------------------------------------------------------------===//
221 
222 /// linear ::= `linear` `(` linear-list `)`
223 /// linear-list := linear-val | linear-val linear-list
224 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
225 static ParseResult
226 parseLinearClause(OpAsmParser &parser,
227                   SmallVectorImpl<OpAsmParser::OperandType> &vars,
228                   SmallVectorImpl<Type> &types,
229                   SmallVectorImpl<OpAsmParser::OperandType> &stepVars) {
230   if (parser.parseLParen())
231     return failure();
232 
233   do {
234     OpAsmParser::OperandType var;
235     Type type;
236     OpAsmParser::OperandType stepVar;
237     if (parser.parseOperand(var) || parser.parseEqual() ||
238         parser.parseOperand(stepVar) || parser.parseColonType(type))
239       return failure();
240 
241     vars.push_back(var);
242     types.push_back(type);
243     stepVars.push_back(stepVar);
244   } while (succeeded(parser.parseOptionalComma()));
245 
246   if (parser.parseRParen())
247     return failure();
248 
249   return success();
250 }
251 
252 /// Print Linear Clause
253 static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars,
254                               OperandRange linearStepVars) {
255   size_t linearVarsSize = linearVars.size();
256   p << "linear(";
257   for (unsigned i = 0; i < linearVarsSize; ++i) {
258     std::string separator = i == linearVarsSize - 1 ? ") " : ", ";
259     p << linearVars[i];
260     if (linearStepVars.size() > i)
261       p << " = " << linearStepVars[i];
262     p << " : " << linearVars[i].getType() << separator;
263   }
264 }
265 
266 //===----------------------------------------------------------------------===//
267 // Parser and printer for Schedule Clause
268 //===----------------------------------------------------------------------===//
269 
270 static ParseResult
271 verifyScheduleModifiers(OpAsmParser &parser,
272                         SmallVectorImpl<SmallString<12>> &modifiers) {
273   if (modifiers.size() > 2)
274     return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
275   for (const auto &mod : modifiers) {
276     // Translate the string. If it has no value, then it was not a valid
277     // modifier!
278     auto symbol = symbolizeScheduleModifier(mod);
279     if (!symbol.hasValue())
280       return parser.emitError(parser.getNameLoc())
281              << " unknown modifier type: " << mod;
282   }
283 
284   // If we have one modifier that is "simd", then stick a "none" modiifer in
285   // index 0.
286   if (modifiers.size() == 1) {
287     if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
288       modifiers.push_back(modifiers[0]);
289       modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
290     }
291   } else if (modifiers.size() == 2) {
292     // If there are two modifier:
293     // First modifier should not be simd, second one should be simd
294     if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
295         symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
296       return parser.emitError(parser.getNameLoc())
297              << " incorrect modifier order";
298   }
299   return success();
300 }
301 
302 /// schedule ::= `schedule` `(` sched-list `)`
303 /// sched-list ::= sched-val | sched-val sched-list |
304 ///                sched-val `,` sched-modifier
305 /// sched-val ::= sched-with-chunk | sched-wo-chunk
306 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
307 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
308 /// sched-wo-chunk ::=  `auto` | `runtime`
309 /// sched-modifier ::=  sched-mod-val | sched-mod-val `,` sched-mod-val
310 /// sched-mod-val ::=  `monotonic` | `nonmonotonic` | `simd` | `none`
311 static ParseResult
312 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
313                     SmallVectorImpl<SmallString<12>> &modifiers,
314                     Optional<OpAsmParser::OperandType> &chunkSize,
315                     Type &chunkType) {
316   if (parser.parseLParen())
317     return failure();
318 
319   StringRef keyword;
320   if (parser.parseKeyword(&keyword))
321     return failure();
322 
323   schedule = keyword;
324   if (keyword == "static" || keyword == "dynamic" || keyword == "guided") {
325     if (succeeded(parser.parseOptionalEqual())) {
326       chunkSize = OpAsmParser::OperandType{};
327       if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
328         return failure();
329     } else {
330       chunkSize = llvm::NoneType::None;
331     }
332   } else if (keyword == "auto" || keyword == "runtime") {
333     chunkSize = llvm::NoneType::None;
334   } else {
335     return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
336   }
337 
338   // If there is a comma, we have one or more modifiers..
339   while (succeeded(parser.parseOptionalComma())) {
340     StringRef mod;
341     if (parser.parseKeyword(&mod))
342       return failure();
343     modifiers.push_back(mod);
344   }
345 
346   if (parser.parseRParen())
347     return failure();
348 
349   if (verifyScheduleModifiers(parser, modifiers))
350     return failure();
351 
352   return success();
353 }
354 
355 /// Print schedule clause
356 static void printScheduleClause(OpAsmPrinter &p, ClauseScheduleKind sched,
357                                 Optional<ScheduleModifier> modifier, bool simd,
358                                 Value scheduleChunkVar) {
359   p << "schedule(" << stringifyClauseScheduleKind(sched).lower();
360   if (scheduleChunkVar)
361     p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType();
362   if (modifier)
363     p << ", " << stringifyScheduleModifier(*modifier);
364   if (simd)
365     p << ", simd";
366   p << ") ";
367 }
368 
369 //===----------------------------------------------------------------------===//
370 // Parser, printer and verifier for ReductionVarList
371 //===----------------------------------------------------------------------===//
372 
373 /// reduction ::= `reduction` `(` reduction-entry-list `)`
374 /// reduction-entry-list ::= reduction-entry
375 ///                        | reduction-entry-list `,` reduction-entry
376 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
377 static ParseResult
378 parseReductionVarList(OpAsmParser &parser,
379                       SmallVectorImpl<SymbolRefAttr> &symbols,
380                       SmallVectorImpl<OpAsmParser::OperandType> &operands,
381                       SmallVectorImpl<Type> &types) {
382   if (failed(parser.parseLParen()))
383     return failure();
384 
385   do {
386     if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() ||
387         parser.parseOperand(operands.emplace_back()) ||
388         parser.parseColonType(types.emplace_back()))
389       return failure();
390   } while (succeeded(parser.parseOptionalComma()));
391   return parser.parseRParen();
392 }
393 
394 /// Print Reduction clause
395 static void printReductionVarList(OpAsmPrinter &p,
396                                   Optional<ArrayAttr> reductions,
397                                   OperandRange reductionVars) {
398   p << "reduction(";
399   for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
400     if (i != 0)
401       p << ", ";
402     p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
403       << reductionVars[i].getType();
404   }
405   p << ") ";
406 }
407 
408 /// Verifies Reduction Clause
409 static LogicalResult verifyReductionVarList(Operation *op,
410                                             Optional<ArrayAttr> reductions,
411                                             OperandRange reductionVars) {
412   if (!reductionVars.empty()) {
413     if (!reductions || reductions->size() != reductionVars.size())
414       return op->emitOpError()
415              << "expected as many reduction symbol references "
416                 "as reduction variables";
417   } else {
418     if (reductions)
419       return op->emitOpError() << "unexpected reduction symbol references";
420     return success();
421   }
422 
423   DenseSet<Value> accumulators;
424   for (auto args : llvm::zip(reductionVars, *reductions)) {
425     Value accum = std::get<0>(args);
426 
427     if (!accumulators.insert(accum).second)
428       return op->emitOpError() << "accumulator variable used more than once";
429 
430     Type varType = accum.getType().cast<PointerLikeType>();
431     auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
432     auto decl =
433         SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
434     if (!decl)
435       return op->emitOpError() << "expected symbol reference " << symbolRef
436                                << " to point to a reduction declaration";
437 
438     if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
439       return op->emitOpError()
440              << "expected accumulator (" << varType
441              << ") to be the same type as reduction declaration ("
442              << decl.getAccumulatorType() << ")";
443   }
444 
445   return success();
446 }
447 
448 //===----------------------------------------------------------------------===//
449 // Parser, printer and verifier for Synchronization Hint (2.17.12)
450 //===----------------------------------------------------------------------===//
451 
452 /// Parses a Synchronization Hint clause. The value of hint is an integer
453 /// which is a combination of different hints from `omp_sync_hint_t`.
454 ///
455 /// hint-clause = `hint` `(` hint-value `)`
456 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
457                                             IntegerAttr &hintAttr,
458                                             bool parseKeyword = true) {
459   if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) {
460     hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
461     return success();
462   }
463 
464   if (failed(parser.parseLParen()))
465     return failure();
466   StringRef hintKeyword;
467   int64_t hint = 0;
468   do {
469     if (failed(parser.parseKeyword(&hintKeyword)))
470       return failure();
471     if (hintKeyword == "uncontended")
472       hint |= 1;
473     else if (hintKeyword == "contended")
474       hint |= 2;
475     else if (hintKeyword == "nonspeculative")
476       hint |= 4;
477     else if (hintKeyword == "speculative")
478       hint |= 8;
479     else
480       return parser.emitError(parser.getCurrentLocation())
481              << hintKeyword << " is not a valid hint";
482   } while (succeeded(parser.parseOptionalComma()));
483   if (failed(parser.parseRParen()))
484     return failure();
485   hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
486   return success();
487 }
488 
489 /// Prints a Synchronization Hint clause
490 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
491                                      IntegerAttr hintAttr) {
492   int64_t hint = hintAttr.getInt();
493 
494   if (hint == 0)
495     return;
496 
497   // Helper function to get n-th bit from the right end of `value`
498   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
499 
500   bool uncontended = bitn(hint, 0);
501   bool contended = bitn(hint, 1);
502   bool nonspeculative = bitn(hint, 2);
503   bool speculative = bitn(hint, 3);
504 
505   SmallVector<StringRef> hints;
506   if (uncontended)
507     hints.push_back("uncontended");
508   if (contended)
509     hints.push_back("contended");
510   if (nonspeculative)
511     hints.push_back("nonspeculative");
512   if (speculative)
513     hints.push_back("speculative");
514 
515   p << "hint(";
516   llvm::interleaveComma(hints, p);
517   p << ") ";
518 }
519 
520 /// Verifies a synchronization hint clause
521 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
522 
523   // Helper function to get n-th bit from the right end of `value`
524   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
525 
526   bool uncontended = bitn(hint, 0);
527   bool contended = bitn(hint, 1);
528   bool nonspeculative = bitn(hint, 2);
529   bool speculative = bitn(hint, 3);
530 
531   if (uncontended && contended)
532     return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
533                                 "omp_sync_hint_contended cannot be combined";
534   if (nonspeculative && speculative)
535     return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
536                                 "omp_sync_hint_speculative cannot be combined.";
537   return success();
538 }
539 
/// Clause kinds understood by the generic clause parser. Enumerators are used
/// directly as indices (e.g. sizing per-clause bit vectors with `COUNT`), so
/// they must stay contiguous from 0 with `COUNT` last.
enum ClauseType {
  ifClause = 0,
  numThreadsClause,
  deviceClause,
  threadLimitClause,
  privateClause,
  firstprivateClause,
  lastprivateClause,
  sharedClause,
  copyinClause,
  allocateClause,
  defaultClause,
  procBindClause,
  reductionClause,
  nowaitClause,
  linearClause,
  scheduleClause,
  collapseClause,
  orderClause,
  orderedClause,
  memoryOrderClause,
  hintClause,
  COUNT // Number of clause kinds; not a clause itself.
};
564 
565 //===----------------------------------------------------------------------===//
566 // Parser for Clause List
567 //===----------------------------------------------------------------------===//
568 
569 /// Parse a clause attribute `(` $value `)`.
570 template <typename ClauseAttr>
571 static ParseResult parseClauseAttr(AsmParser &parser, OperationState &state,
572                                    StringRef attrName, StringRef name) {
573   using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
574   StringRef enumStr;
575   SMLoc loc = parser.getCurrentLocation();
576   if (parser.parseLParen() || parser.parseKeyword(&enumStr) ||
577       parser.parseRParen())
578     return failure();
579   if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
580     auto attr = ClauseAttr::get(parser.getContext(), *enumValue);
581     state.addAttribute(attrName, attr);
582     return success();
583   }
584   return parser.emitError(loc, "invalid ") << name << " kind";
585 }
586 
/// Parse a list of clauses. The clauses can appear in any order, but their
/// operand segment indices are in the same order that they are passed in the
/// `clauses` list. The operand segments are added on top of the previously
/// collected segments.
///
591 /// clause-list ::= clause clause-list | empty
592 /// clause ::= if | num-threads | private | firstprivate | lastprivate |
593 ///            shared | copyin | allocate | default | proc-bind | reduction |
594 ///            nowait | linear | schedule | collapse | order | ordered |
595 ///            inclusive
596 /// if ::= `if` `(` ssa-id-and-type `)`
597 /// num-threads ::= `num_threads` `(` ssa-id-and-type `)`
598 /// private ::= `private` operand-and-type-list
599 /// firstprivate ::= `firstprivate` operand-and-type-list
600 /// lastprivate ::= `lastprivate` operand-and-type-list
601 /// shared ::= `shared` operand-and-type-list
602 /// copyin ::= `copyin` operand-and-type-list
603 /// allocate ::= `allocate` `(` allocate-operand-list `)`
/// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`) `)`
605 /// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
606 /// reduction ::= `reduction` `(` reduction-entry-list `)`
607 /// nowait ::= `nowait`
608 /// linear ::= `linear` `(` linear-list `)`
609 /// schedule ::= `schedule` `(` sched-list `)`
610 /// collapse ::= `collapse` `(` ssa-id-and-type `)`
611 /// order ::= `order` `(` `concurrent` `)`
612 /// ordered ::= `ordered` `(` ssa-id-and-type `)`
613 /// inclusive ::= `inclusive`
614 ///
/// Note that each clause can only appear once in the clause-list.
616 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
617                                 SmallVectorImpl<ClauseType> &clauses,
618                                 SmallVectorImpl<int> &segments) {
619 
620   // Check done[clause] to see if it has been parsed already
621   BitVector done(ClauseType::COUNT, false);
622 
623   // See pos[clause] to get position of clause in operand segments
624   SmallVector<int> pos(ClauseType::COUNT, -1);
625 
626   // Stores the last parsed clause keyword
627   StringRef clauseKeyword;
628   StringRef opName = result.name.getStringRef();
629 
630   // Containers for storing operands, types and attributes for various clauses
631   std::pair<OpAsmParser::OperandType, Type> ifCond;
632   std::pair<OpAsmParser::OperandType, Type> numThreads;
633   std::pair<OpAsmParser::OperandType, Type> device;
634   std::pair<OpAsmParser::OperandType, Type> threadLimit;
635 
636   SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates,
637       shareds, copyins;
638   SmallVector<Type> privateTypes, firstprivateTypes, lastprivateTypes,
639       sharedTypes, copyinTypes;
640 
641   SmallVector<OpAsmParser::OperandType> allocates, allocators;
642   SmallVector<Type> allocateTypes, allocatorTypes;
643 
644   SmallVector<SymbolRefAttr> reductionSymbols;
645   SmallVector<OpAsmParser::OperandType> reductionVars;
646   SmallVector<Type> reductionVarTypes;
647 
648   SmallVector<OpAsmParser::OperandType> linears;
649   SmallVector<Type> linearTypes;
650   SmallVector<OpAsmParser::OperandType> linearSteps;
651 
652   SmallString<8> schedule;
653   SmallVector<SmallString<12>> modifiers;
654   Optional<OpAsmParser::OperandType> scheduleChunkSize;
655   Type scheduleChunkType;
656 
657   // Compute the position of clauses in operand segments
658   int currPos = 0;
659   for (ClauseType clause : clauses) {
660 
661     // Skip the following clauses - they do not take any position in operand
662     // segments
663     if (clause == defaultClause || clause == procBindClause ||
664         clause == nowaitClause || clause == collapseClause ||
665         clause == orderClause || clause == orderedClause)
666       continue;
667 
668     pos[clause] = currPos++;
669 
670     // For the following clauses, two positions are reserved in the operand
671     // segments
672     if (clause == allocateClause || clause == linearClause)
673       currPos++;
674   }
675 
676   SmallVector<int> clauseSegments(currPos);
677 
678   // Helper function to check if a clause is allowed/repeated or not
679   auto checkAllowed = [&](ClauseType clause) -> ParseResult {
680     if (!llvm::is_contained(clauses, clause))
681       return parser.emitError(parser.getCurrentLocation())
682              << clauseKeyword << " is not a valid clause for the " << opName
683              << " operation";
684     if (done[clause])
685       return parser.emitError(parser.getCurrentLocation())
686              << "at most one " << clauseKeyword << " clause can appear on the "
687              << opName << " operation";
688     done[clause] = true;
689     return success();
690   };
691 
692   while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) {
693     if (clauseKeyword == "if") {
694       if (checkAllowed(ifClause) || parser.parseLParen() ||
695           parser.parseOperand(ifCond.first) ||
696           parser.parseColonType(ifCond.second) || parser.parseRParen())
697         return failure();
698       clauseSegments[pos[ifClause]] = 1;
699     } else if (clauseKeyword == "num_threads") {
700       if (checkAllowed(numThreadsClause) || parser.parseLParen() ||
701           parser.parseOperand(numThreads.first) ||
702           parser.parseColonType(numThreads.second) || parser.parseRParen())
703         return failure();
704       clauseSegments[pos[numThreadsClause]] = 1;
705     } else if (clauseKeyword == "device") {
706       if (checkAllowed(deviceClause) || parser.parseLParen() ||
707           parser.parseOperand(device.first) ||
708           parser.parseColonType(device.second) || parser.parseRParen())
709         return failure();
710       clauseSegments[pos[deviceClause]] = 1;
711     } else if (clauseKeyword == "thread_limit") {
712       if (checkAllowed(threadLimitClause) || parser.parseLParen() ||
713           parser.parseOperand(threadLimit.first) ||
714           parser.parseColonType(threadLimit.second) || parser.parseRParen())
715         return failure();
716       clauseSegments[pos[threadLimitClause]] = 1;
717     } else if (clauseKeyword == "private") {
718       if (checkAllowed(privateClause) ||
719           parseOperandAndTypeList(parser, privates, privateTypes))
720         return failure();
721       clauseSegments[pos[privateClause]] = privates.size();
722     } else if (clauseKeyword == "firstprivate") {
723       if (checkAllowed(firstprivateClause) ||
724           parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
725         return failure();
726       clauseSegments[pos[firstprivateClause]] = firstprivates.size();
727     } else if (clauseKeyword == "lastprivate") {
728       if (checkAllowed(lastprivateClause) ||
729           parseOperandAndTypeList(parser, lastprivates, lastprivateTypes))
730         return failure();
731       clauseSegments[pos[lastprivateClause]] = lastprivates.size();
732     } else if (clauseKeyword == "shared") {
733       if (checkAllowed(sharedClause) ||
734           parseOperandAndTypeList(parser, shareds, sharedTypes))
735         return failure();
736       clauseSegments[pos[sharedClause]] = shareds.size();
737     } else if (clauseKeyword == "copyin") {
738       if (checkAllowed(copyinClause) ||
739           parseOperandAndTypeList(parser, copyins, copyinTypes))
740         return failure();
741       clauseSegments[pos[copyinClause]] = copyins.size();
742     } else if (clauseKeyword == "allocate") {
743       if (checkAllowed(allocateClause) ||
744           parseAllocateAndAllocator(parser, allocates, allocateTypes,
745                                     allocators, allocatorTypes))
746         return failure();
747       clauseSegments[pos[allocateClause]] = allocates.size();
748       clauseSegments[pos[allocateClause] + 1] = allocators.size();
749     } else if (clauseKeyword == "default") {
750       StringRef defval;
751       SMLoc loc = parser.getCurrentLocation();
752       if (checkAllowed(defaultClause) || parser.parseLParen() ||
753           parser.parseKeyword(&defval) || parser.parseRParen())
754         return failure();
755       // The def prefix is required for the attribute as "private" is a keyword
756       // in C++.
757       if (Optional<ClauseDefault> def =
758               symbolizeClauseDefault(("def" + defval).str())) {
759         result.addAttribute("default_val",
760                             ClauseDefaultAttr::get(parser.getContext(), *def));
761       } else {
762         return parser.emitError(loc, "invalid default clause");
763       }
764     } else if (clauseKeyword == "proc_bind") {
765       if (checkAllowed(procBindClause) ||
766           parseClauseAttr<ClauseProcBindKindAttr>(parser, result,
767                                                   "proc_bind_val", "proc bind"))
768         return failure();
769     } else if (clauseKeyword == "reduction") {
770       if (checkAllowed(reductionClause) ||
771           parseReductionVarList(parser, reductionSymbols, reductionVars,
772                                 reductionVarTypes))
773         return failure();
774       clauseSegments[pos[reductionClause]] = reductionVars.size();
775     } else if (clauseKeyword == "nowait") {
776       if (checkAllowed(nowaitClause))
777         return failure();
778       auto attr = UnitAttr::get(parser.getBuilder().getContext());
779       result.addAttribute("nowait", attr);
780     } else if (clauseKeyword == "linear") {
781       if (checkAllowed(linearClause) ||
782           parseLinearClause(parser, linears, linearTypes, linearSteps))
783         return failure();
784       clauseSegments[pos[linearClause]] = linears.size();
785       clauseSegments[pos[linearClause] + 1] = linearSteps.size();
786     } else if (clauseKeyword == "schedule") {
787       if (checkAllowed(scheduleClause) ||
788           parseScheduleClause(parser, schedule, modifiers, scheduleChunkSize,
789                               scheduleChunkType))
790         return failure();
791       if (scheduleChunkSize) {
792         clauseSegments[pos[scheduleClause]] = 1;
793       }
794     } else if (clauseKeyword == "collapse") {
795       auto type = parser.getBuilder().getI64Type();
796       mlir::IntegerAttr attr;
797       if (checkAllowed(collapseClause) || parser.parseLParen() ||
798           parser.parseAttribute(attr, type) || parser.parseRParen())
799         return failure();
800       result.addAttribute("collapse_val", attr);
801     } else if (clauseKeyword == "ordered") {
802       mlir::IntegerAttr attr;
803       if (checkAllowed(orderedClause))
804         return failure();
805       if (succeeded(parser.parseOptionalLParen())) {
806         auto type = parser.getBuilder().getI64Type();
807         if (parser.parseAttribute(attr, type) || parser.parseRParen())
808           return failure();
809       } else {
810         // Use 0 to represent no ordered parameter was specified
811         attr = parser.getBuilder().getI64IntegerAttr(0);
812       }
813       result.addAttribute("ordered_val", attr);
814     } else if (clauseKeyword == "order") {
815       if (checkAllowed(orderClause) ||
816           parseClauseAttr<ClauseOrderKindAttr>(parser, result, "order_val",
817                                                "order"))
818         return failure();
819     } else if (clauseKeyword == "memory_order") {
820       if (checkAllowed(memoryOrderClause) ||
821           parseClauseAttr<ClauseMemoryOrderKindAttr>(
822               parser, result, "memory_order", "memory order"))
823         return failure();
824     } else if (clauseKeyword == "hint") {
825       IntegerAttr hint;
826       if (checkAllowed(hintClause) ||
827           parseSynchronizationHint(parser, hint, false))
828         return failure();
829       result.addAttribute("hint", hint);
830     } else {
831       return parser.emitError(parser.getNameLoc())
832              << clauseKeyword << " is not a valid clause";
833     }
834   }
835 
836   // Add if parameter.
837   if (done[ifClause] && clauseSegments[pos[ifClause]] &&
838       failed(
839           parser.resolveOperand(ifCond.first, ifCond.second, result.operands)))
840     return failure();
841 
842   // Add num_threads parameter.
843   if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] &&
844       failed(parser.resolveOperand(numThreads.first, numThreads.second,
845                                    result.operands)))
846     return failure();
847 
848   // Add device parameter.
849   if (done[deviceClause] && clauseSegments[pos[deviceClause]] &&
850       failed(
851           parser.resolveOperand(device.first, device.second, result.operands)))
852     return failure();
853 
854   // Add thread_limit parameter.
855   if (done[threadLimitClause] && clauseSegments[pos[threadLimitClause]] &&
856       failed(parser.resolveOperand(threadLimit.first, threadLimit.second,
857                                    result.operands)))
858     return failure();
859 
860   // Add private parameters.
861   if (done[privateClause] && clauseSegments[pos[privateClause]] &&
862       failed(parser.resolveOperands(privates, privateTypes,
863                                     privates[0].location, result.operands)))
864     return failure();
865 
866   // Add firstprivate parameters.
867   if (done[firstprivateClause] && clauseSegments[pos[firstprivateClause]] &&
868       failed(parser.resolveOperands(firstprivates, firstprivateTypes,
869                                     firstprivates[0].location,
870                                     result.operands)))
871     return failure();
872 
873   // Add lastprivate parameters.
874   if (done[lastprivateClause] && clauseSegments[pos[lastprivateClause]] &&
875       failed(parser.resolveOperands(lastprivates, lastprivateTypes,
876                                     lastprivates[0].location, result.operands)))
877     return failure();
878 
879   // Add shared parameters.
880   if (done[sharedClause] && clauseSegments[pos[sharedClause]] &&
881       failed(parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
882                                     result.operands)))
883     return failure();
884 
885   // Add copyin parameters.
886   if (done[copyinClause] && clauseSegments[pos[copyinClause]] &&
887       failed(parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
888                                     result.operands)))
889     return failure();
890 
891   // Add allocate parameters.
892   if (done[allocateClause] && clauseSegments[pos[allocateClause]] &&
893       failed(parser.resolveOperands(allocates, allocateTypes,
894                                     allocates[0].location, result.operands)))
895     return failure();
896 
897   // Add allocator parameters.
898   if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] &&
899       failed(parser.resolveOperands(allocators, allocatorTypes,
900                                     allocators[0].location, result.operands)))
901     return failure();
902 
903   // Add reduction parameters and symbols
904   if (done[reductionClause] && clauseSegments[pos[reductionClause]]) {
905     if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
906                                       parser.getNameLoc(), result.operands)))
907       return failure();
908 
909     SmallVector<Attribute> reductions(reductionSymbols.begin(),
910                                       reductionSymbols.end());
911     result.addAttribute("reductions",
912                         parser.getBuilder().getArrayAttr(reductions));
913   }
914 
915   // Add linear parameters
916   if (done[linearClause] && clauseSegments[pos[linearClause]]) {
917     auto linearStepType = parser.getBuilder().getI32Type();
918     SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
919     if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location,
920                                       result.operands)) ||
921         failed(parser.resolveOperands(linearSteps, linearStepTypes,
922                                       linearSteps[0].location,
923                                       result.operands)))
924       return failure();
925   }
926 
927   // Add schedule parameters
928   if (done[scheduleClause] && !schedule.empty()) {
929     schedule[0] = llvm::toUpper(schedule[0]);
930     if (Optional<ClauseScheduleKind> sched =
931             symbolizeClauseScheduleKind(schedule)) {
932       auto attr = ClauseScheduleKindAttr::get(parser.getContext(), *sched);
933       result.addAttribute("schedule_val", attr);
934     } else {
935       return parser.emitError(parser.getCurrentLocation(),
936                               "invalid schedule kind");
937     }
938     if (!modifiers.empty()) {
939       SMLoc loc = parser.getCurrentLocation();
940       if (Optional<ScheduleModifier> mod =
941               symbolizeScheduleModifier(modifiers[0])) {
942         result.addAttribute(
943             "schedule_modifier",
944             ScheduleModifierAttr::get(parser.getContext(), *mod));
945       } else {
946         return parser.emitError(loc, "invalid schedule modifier");
947       }
948       // Only SIMD attribute is allowed here!
949       if (modifiers.size() > 1) {
950         assert(symbolizeScheduleModifier(modifiers[1]) ==
951                ScheduleModifier::simd);
952         auto attr = UnitAttr::get(parser.getBuilder().getContext());
953         result.addAttribute("simd_modifier", attr);
954       }
955     }
956     if (scheduleChunkSize)
957       parser.resolveOperand(*scheduleChunkSize, scheduleChunkType,
958                             result.operands);
959   }
960 
961   segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end());
962 
963   return success();
964 }
965 
966 /// Parses a parallel operation.
967 ///
968 /// operation ::= `omp.parallel` clause-list
969 /// clause-list ::= clause | clause clause-list
970 /// clause ::= if | num-threads | private | firstprivate | shared | copyin |
971 ///            allocate | default | proc-bind
972 ///
973 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
974   SmallVector<ClauseType> clauses = {
975       ifClause,           numThreadsClause, privateClause,
976       firstprivateClause, sharedClause,     copyinClause,
977       allocateClause,     defaultClause,    procBindClause};
978 
979   SmallVector<int> segments;
980 
981   if (failed(parseClauses(parser, result, clauses, segments)))
982     return failure();
983 
984   result.addAttribute("operand_segment_sizes",
985                       parser.getBuilder().getI32VectorAttr(segments));
986 
987   Region *body = result.addRegion();
988   SmallVector<OpAsmParser::OperandType> regionArgs;
989   SmallVector<Type> regionArgTypes;
990   if (parser.parseRegion(*body, regionArgs, regionArgTypes))
991     return failure();
992   return success();
993 }
994 
995 /// Parses a target operation.
996 ///
997 /// operation ::= `omp.target` clause-list
998 /// clause-list ::= clause | clause clause-list
999 /// clause ::= if | device | thread_limit | nowait
1000 ///
1001 ParseResult TargetOp::parse(OpAsmParser &parser, OperationState &result) {
1002   SmallVector<ClauseType> clauses = {ifClause, deviceClause, threadLimitClause,
1003                                      nowaitClause};
1004 
1005   SmallVector<int> segments;
1006 
1007   if (failed(parseClauses(parser, result, clauses, segments)))
1008     return failure();
1009 
1010   result.addAttribute(
1011       TargetOp::AttrSizedOperandSegments::getOperandSegmentSizeAttr(),
1012       parser.getBuilder().getI32VectorAttr(segments));
1013 
1014   Region *body = result.addRegion();
1015   SmallVector<OpAsmParser::OperandType> regionArgs;
1016   SmallVector<Type> regionArgTypes;
1017   if (parser.parseRegion(*body, regionArgs, regionArgTypes))
1018     return failure();
1019   return success();
1020 }
1021 
1022 //===----------------------------------------------------------------------===//
1023 // Parser, printer and verifier for SectionsOp
1024 //===----------------------------------------------------------------------===//
1025 
1026 /// Parses an OpenMP Sections operation
1027 ///
1028 /// sections ::= `omp.sections` clause-list
1029 /// clause-list ::= clause clause-list | empty
1030 /// clause ::= private | firstprivate | lastprivate | reduction | allocate |
1031 ///            nowait
1032 ParseResult SectionsOp::parse(OpAsmParser &parser, OperationState &result) {
1033   SmallVector<ClauseType> clauses = {privateClause,     firstprivateClause,
1034                                      lastprivateClause, reductionClause,
1035                                      allocateClause,    nowaitClause};
1036 
1037   SmallVector<int> segments;
1038 
1039   if (failed(parseClauses(parser, result, clauses, segments)))
1040     return failure();
1041 
1042   result.addAttribute("operand_segment_sizes",
1043                       parser.getBuilder().getI32VectorAttr(segments));
1044 
1045   // Now parse the body.
1046   Region *body = result.addRegion();
1047   if (parser.parseRegion(*body))
1048     return failure();
1049   return success();
1050 }
1051 
1052 void SectionsOp::print(OpAsmPrinter &p) {
1053   p << " ";
1054   printDataVars(p, private_vars(), "private");
1055   printDataVars(p, firstprivate_vars(), "firstprivate");
1056   printDataVars(p, lastprivate_vars(), "lastprivate");
1057 
1058   if (!reduction_vars().empty())
1059     printReductionVarList(p, reductions(), reduction_vars());
1060 
1061   if (!allocate_vars().empty())
1062     printAllocateAndAllocator(p, allocate_vars(), allocators_vars());
1063 
1064   if (nowait())
1065     p << "nowait";
1066 
1067   p << ' ';
1068   p.printRegion(region());
1069 }
1070 
1071 LogicalResult SectionsOp::verify() {
1072   // A list item may not appear in more than one clause on the same directive,
1073   // except that it may be specified in both firstprivate and lastprivate
1074   // clauses.
1075   for (auto var : private_vars()) {
1076     if (llvm::is_contained(firstprivate_vars(), var))
1077       return emitOpError()
1078              << "operand used in both private and firstprivate clauses";
1079     if (llvm::is_contained(lastprivate_vars(), var))
1080       return emitOpError()
1081              << "operand used in both private and lastprivate clauses";
1082   }
1083 
1084   if (allocate_vars().size() != allocators_vars().size())
1085     return emitError(
1086         "expected equal sizes for allocate and allocator variables");
1087 
1088   for (auto &inst : *region().begin()) {
1089     if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
1090       return emitOpError()
1091              << "expected omp.section op or terminator op inside region";
1092     }
1093   }
1094 
1095   return verifyReductionVarList(*this, reductions(), reduction_vars());
1096 }
1097 
1098 /// Parses an OpenMP Workshare Loop operation
1099 ///
1100 /// wsloop ::= `omp.wsloop` loop-control clause-list
1101 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
1102 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
1103 /// steps := `step` `(`ssa-id-list`)`
1104 /// clause-list ::= clause clause-list | empty
1105 /// clause ::= private | firstprivate | lastprivate | linear | schedule |
1106 //             collapse | nowait | ordered | order | reduction
1107 ParseResult WsLoopOp::parse(OpAsmParser &parser, OperationState &result) {
1108   // Parse an opening `(` followed by induction variables followed by `)`
1109   SmallVector<OpAsmParser::OperandType> ivs;
1110   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
1111                                      OpAsmParser::Delimiter::Paren))
1112     return failure();
1113 
1114   int numIVs = static_cast<int>(ivs.size());
1115   Type loopVarType;
1116   if (parser.parseColonType(loopVarType))
1117     return failure();
1118 
1119   // Parse loop bounds.
1120   SmallVector<OpAsmParser::OperandType> lower;
1121   if (parser.parseEqual() ||
1122       parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
1123       parser.resolveOperands(lower, loopVarType, result.operands))
1124     return failure();
1125 
1126   SmallVector<OpAsmParser::OperandType> upper;
1127   if (parser.parseKeyword("to") ||
1128       parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
1129       parser.resolveOperands(upper, loopVarType, result.operands))
1130     return failure();
1131 
1132   if (succeeded(parser.parseOptionalKeyword("inclusive"))) {
1133     auto attr = UnitAttr::get(parser.getBuilder().getContext());
1134     result.addAttribute("inclusive", attr);
1135   }
1136 
1137   // Parse step values.
1138   SmallVector<OpAsmParser::OperandType> steps;
1139   if (parser.parseKeyword("step") ||
1140       parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
1141       parser.resolveOperands(steps, loopVarType, result.operands))
1142     return failure();
1143 
1144   SmallVector<ClauseType> clauses = {
1145       privateClause,   firstprivateClause, lastprivateClause, linearClause,
1146       reductionClause, collapseClause,     orderClause,       orderedClause,
1147       nowaitClause,    scheduleClause};
1148   SmallVector<int> segments{numIVs, numIVs, numIVs};
1149   if (failed(parseClauses(parser, result, clauses, segments)))
1150     return failure();
1151 
1152   result.addAttribute("operand_segment_sizes",
1153                       parser.getBuilder().getI32VectorAttr(segments));
1154 
1155   // Now parse the body.
1156   Region *body = result.addRegion();
1157   SmallVector<Type> ivTypes(numIVs, loopVarType);
1158   SmallVector<OpAsmParser::OperandType> blockArgs(ivs);
1159   if (parser.parseRegion(*body, blockArgs, ivTypes))
1160     return failure();
1161   return success();
1162 }
1163 
1164 void WsLoopOp::print(OpAsmPrinter &p) {
1165   auto args = getRegion().front().getArguments();
1166   p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound()
1167     << ") to (" << upperBound() << ") ";
1168   if (inclusive()) {
1169     p << "inclusive ";
1170   }
1171   p << "step (" << step() << ") ";
1172 
1173   printDataVars(p, private_vars(), "private");
1174   printDataVars(p, firstprivate_vars(), "firstprivate");
1175   printDataVars(p, lastprivate_vars(), "lastprivate");
1176 
1177   if (!linear_vars().empty())
1178     printLinearClause(p, linear_vars(), linear_step_vars());
1179 
1180   if (auto sched = schedule_val())
1181     printScheduleClause(p, sched.getValue(), schedule_modifier(),
1182                         simd_modifier(), schedule_chunk_var());
1183 
1184   if (auto collapse = collapse_val())
1185     p << "collapse(" << collapse << ") ";
1186 
1187   if (nowait())
1188     p << "nowait ";
1189 
1190   if (auto ordered = ordered_val())
1191     p << "ordered(" << ordered << ") ";
1192 
1193   if (auto order = order_val())
1194     p << "order(" << stringifyClauseOrderKind(*order) << ") ";
1195 
1196   if (!reduction_vars().empty())
1197     printReductionVarList(p, reductions(), reduction_vars());
1198 
1199   p << ' ';
1200   p.printRegion(region(), /*printEntryBlockArgs=*/false);
1201 }
1202 
1203 //===----------------------------------------------------------------------===//
1204 // ReductionOp
1205 //===----------------------------------------------------------------------===//
1206 
1207 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
1208                                               Region &region) {
1209   if (parser.parseOptionalKeyword("atomic"))
1210     return success();
1211   return parser.parseRegion(region);
1212 }
1213 
1214 static void printAtomicReductionRegion(OpAsmPrinter &printer,
1215                                        ReductionDeclareOp op, Region &region) {
1216   if (region.empty())
1217     return;
1218   printer << "atomic ";
1219   printer.printRegion(region);
1220 }
1221 
1222 LogicalResult ReductionDeclareOp::verify() {
1223   if (initializerRegion().empty())
1224     return emitOpError() << "expects non-empty initializer region";
1225   Block &initializerEntryBlock = initializerRegion().front();
1226   if (initializerEntryBlock.getNumArguments() != 1 ||
1227       initializerEntryBlock.getArgument(0).getType() != type()) {
1228     return emitOpError() << "expects initializer region with one argument "
1229                             "of the reduction type";
1230   }
1231 
1232   for (YieldOp yieldOp : initializerRegion().getOps<YieldOp>()) {
1233     if (yieldOp.results().size() != 1 ||
1234         yieldOp.results().getTypes()[0] != type())
1235       return emitOpError() << "expects initializer region to yield a value "
1236                               "of the reduction type";
1237   }
1238 
1239   if (reductionRegion().empty())
1240     return emitOpError() << "expects non-empty reduction region";
1241   Block &reductionEntryBlock = reductionRegion().front();
1242   if (reductionEntryBlock.getNumArguments() != 2 ||
1243       reductionEntryBlock.getArgumentTypes()[0] !=
1244           reductionEntryBlock.getArgumentTypes()[1] ||
1245       reductionEntryBlock.getArgumentTypes()[0] != type())
1246     return emitOpError() << "expects reduction region with two arguments of "
1247                             "the reduction type";
1248   for (YieldOp yieldOp : reductionRegion().getOps<YieldOp>()) {
1249     if (yieldOp.results().size() != 1 ||
1250         yieldOp.results().getTypes()[0] != type())
1251       return emitOpError() << "expects reduction region to yield a value "
1252                               "of the reduction type";
1253   }
1254 
1255   if (atomicReductionRegion().empty())
1256     return success();
1257 
1258   Block &atomicReductionEntryBlock = atomicReductionRegion().front();
1259   if (atomicReductionEntryBlock.getNumArguments() != 2 ||
1260       atomicReductionEntryBlock.getArgumentTypes()[0] !=
1261           atomicReductionEntryBlock.getArgumentTypes()[1])
1262     return emitOpError() << "expects atomic reduction region with two "
1263                             "arguments of the same type";
1264   auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
1265                      .dyn_cast<PointerLikeType>();
1266   if (!ptrType || ptrType.getElementType() != type())
1267     return emitOpError() << "expects atomic reduction region arguments to "
1268                             "be accumulators containing the reduction type";
1269   return success();
1270 }
1271 
1272 LogicalResult ReductionOp::verify() {
1273   // TODO: generalize this to an op interface when there is more than one op
1274   // that supports reductions.
1275   auto container = (*this)->getParentOfType<WsLoopOp>();
1276   for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
1277     if (container.reduction_vars()[i] == accumulator())
1278       return success();
1279 
1280   return emitOpError() << "the accumulator is not used by the parent";
1281 }
1282 
1283 //===----------------------------------------------------------------------===//
1284 // WsLoopOp
1285 //===----------------------------------------------------------------------===//
1286 
1287 void WsLoopOp::build(OpBuilder &builder, OperationState &state,
1288                      ValueRange lowerBound, ValueRange upperBound,
1289                      ValueRange step, ArrayRef<NamedAttribute> attributes) {
1290   build(builder, state, TypeRange(), lowerBound, upperBound, step,
1291         /*privateVars=*/ValueRange(),
1292         /*firstprivateVars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
1293         /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
1294         /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr,
1295         /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr,
1296         /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr,
1297         /*inclusive=*/nullptr, /*buildBody=*/false);
1298   state.addAttributes(attributes);
1299 }
1300 
1301 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes,
1302                      ValueRange operands, ArrayRef<NamedAttribute> attributes) {
1303   state.addOperands(operands);
1304   state.addAttributes(attributes);
1305   (void)state.addRegion();
1306   assert(resultTypes.empty() && "mismatched number of return types");
1307   state.addTypes(resultTypes);
1308 }
1309 
1310 void WsLoopOp::build(OpBuilder &builder, OperationState &result,
1311                      TypeRange typeRange, ValueRange lowerBounds,
1312                      ValueRange upperBounds, ValueRange steps,
1313                      ValueRange privateVars, ValueRange firstprivateVars,
1314                      ValueRange lastprivateVars, ValueRange linearVars,
1315                      ValueRange linearStepVars, ValueRange reductionVars,
1316                      StringAttr scheduleVal, Value scheduleChunkVar,
1317                      IntegerAttr collapseVal, UnitAttr nowait,
1318                      IntegerAttr orderedVal, StringAttr orderVal,
1319                      UnitAttr inclusive, bool buildBody) {
1320   result.addOperands(lowerBounds);
1321   result.addOperands(upperBounds);
1322   result.addOperands(steps);
1323   result.addOperands(privateVars);
1324   result.addOperands(firstprivateVars);
1325   result.addOperands(linearVars);
1326   result.addOperands(linearStepVars);
1327   if (scheduleChunkVar)
1328     result.addOperands(scheduleChunkVar);
1329 
1330   if (scheduleVal)
1331     result.addAttribute("schedule_val", scheduleVal);
1332   if (collapseVal)
1333     result.addAttribute("collapse_val", collapseVal);
1334   if (nowait)
1335     result.addAttribute("nowait", nowait);
1336   if (orderedVal)
1337     result.addAttribute("ordered_val", orderedVal);
1338   if (orderVal)
1339     result.addAttribute("order", orderVal);
1340   if (inclusive)
1341     result.addAttribute("inclusive", inclusive);
1342   result.addAttribute(
1343       WsLoopOp::getOperandSegmentSizeAttr(),
1344       builder.getI32VectorAttr(
1345           {static_cast<int32_t>(lowerBounds.size()),
1346            static_cast<int32_t>(upperBounds.size()),
1347            static_cast<int32_t>(steps.size()),
1348            static_cast<int32_t>(privateVars.size()),
1349            static_cast<int32_t>(firstprivateVars.size()),
1350            static_cast<int32_t>(lastprivateVars.size()),
1351            static_cast<int32_t>(linearVars.size()),
1352            static_cast<int32_t>(linearStepVars.size()),
1353            static_cast<int32_t>(reductionVars.size()),
1354            static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)}));
1355 
1356   Region *bodyRegion = result.addRegion();
1357   if (buildBody) {
1358     OpBuilder::InsertionGuard guard(builder);
1359     unsigned numIVs = steps.size();
1360     SmallVector<Type, 8> argTypes(numIVs, steps.getType().front());
1361     SmallVector<Location, 8> argLocs(numIVs, result.location);
1362     builder.createBlock(bodyRegion, {}, argTypes, argLocs);
1363   }
1364 }
1365 
1366 LogicalResult WsLoopOp::verify() {
1367   return verifyReductionVarList(*this, reductions(), reduction_vars());
1368 }
1369 
1370 //===----------------------------------------------------------------------===//
1371 // Verifier for critical construct (2.17.1)
1372 //===----------------------------------------------------------------------===//
1373 
1374 LogicalResult CriticalDeclareOp::verify() {
1375   return verifySynchronizationHint(*this, hint());
1376 }
1377 
1378 LogicalResult CriticalOp::verify() {
1379   if (nameAttr()) {
1380     SymbolRefAttr symbolRef = nameAttr();
1381     auto decl = SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(
1382         *this, symbolRef);
1383     if (!decl) {
1384       return emitOpError() << "expected symbol reference " << symbolRef
1385                            << " to point to a critical declaration";
1386     }
1387   }
1388 
1389   return success();
1390 }
1391 
1392 //===----------------------------------------------------------------------===//
1393 // Verifier for ordered construct
1394 //===----------------------------------------------------------------------===//
1395 
1396 LogicalResult OrderedOp::verify() {
1397   auto container = (*this)->getParentOfType<WsLoopOp>();
1398   if (!container || !container.ordered_valAttr() ||
1399       container.ordered_valAttr().getInt() == 0)
1400     return emitOpError() << "ordered depend directive must be closely "
1401                          << "nested inside a worksharing-loop with ordered "
1402                          << "clause with parameter present";
1403 
1404   if (container.ordered_valAttr().getInt() !=
1405       (int64_t)num_loops_val().getValue())
1406     return emitOpError() << "number of variables in depend clause does not "
1407                          << "match number of iteration variables in the "
1408                          << "doacross loop";
1409 
1410   return success();
1411 }
1412 
1413 LogicalResult OrderedRegionOp::verify() {
1414   // TODO: The code generation for ordered simd directive is not supported yet.
1415   if (simd())
1416     return failure();
1417 
1418   if (auto container = (*this)->getParentOfType<WsLoopOp>()) {
1419     if (!container.ordered_valAttr() ||
1420         container.ordered_valAttr().getInt() != 0)
1421       return emitOpError() << "ordered region must be closely nested inside "
1422                            << "a worksharing-loop region with an ordered "
1423                            << "clause without parameter present";
1424   }
1425 
1426   return success();
1427 }
1428 
1429 //===----------------------------------------------------------------------===//
1430 // AtomicReadOp
1431 //===----------------------------------------------------------------------===//
1432 
1433 /// Parser for AtomicReadOp
1434 ///
1435 /// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type
1436 /// address ::= operand `:` type
1437 ParseResult AtomicReadOp::parse(OpAsmParser &parser, OperationState &result) {
1438   OpAsmParser::OperandType x, v;
1439   Type addressType;
1440   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1441   SmallVector<int> segments;
1442 
1443   if (parser.parseOperand(v) || parser.parseEqual() || parser.parseOperand(x) ||
1444       parseClauses(parser, result, clauses, segments) ||
1445       parser.parseColonType(addressType) ||
1446       parser.resolveOperand(x, addressType, result.operands) ||
1447       parser.resolveOperand(v, addressType, result.operands))
1448     return failure();
1449 
1450   return success();
1451 }
1452 
1453 void AtomicReadOp::print(OpAsmPrinter &p) {
1454   p << " " << v() << " = " << x() << " ";
1455   if (auto mo = memory_order())
1456     p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
1457   if (hintAttr())
1458     printSynchronizationHint(p << " ", *this, hintAttr());
1459   p << ": " << x().getType();
1460 }
1461 
1462 /// Verifier for AtomicReadOp
1463 LogicalResult AtomicReadOp::verify() {
1464   if (auto mo = memory_order()) {
1465     if (*mo == ClauseMemoryOrderKind::acq_rel ||
1466         *mo == ClauseMemoryOrderKind::release) {
1467       return emitError(
1468           "memory-order must not be acq_rel or release for atomic reads");
1469     }
1470   }
1471   if (x() == v())
1472     return emitError(
1473         "read and write must not be to the same location for atomic reads");
1474   return verifySynchronizationHint(*this, hint());
1475 }
1476 
1477 //===----------------------------------------------------------------------===//
1478 // AtomicWriteOp
1479 //===----------------------------------------------------------------------===//
1480 
1481 /// Parser for AtomicWriteOp
1482 ///
1483 /// operation ::= `omp.atomic.write` atomic-clause-list operands
1484 /// operands ::= address `,` value
1485 /// address ::= operand `:` type
1486 /// value ::= operand `:` type
1487 ParseResult AtomicWriteOp::parse(OpAsmParser &parser, OperationState &result) {
1488   OpAsmParser::OperandType address, value;
1489   Type addrType, valueType;
1490   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1491   SmallVector<int> segments;
1492 
1493   if (parser.parseOperand(address) || parser.parseEqual() ||
1494       parser.parseOperand(value) ||
1495       parseClauses(parser, result, clauses, segments) ||
1496       parser.parseColonType(addrType) || parser.parseComma() ||
1497       parser.parseType(valueType) ||
1498       parser.resolveOperand(address, addrType, result.operands) ||
1499       parser.resolveOperand(value, valueType, result.operands))
1500     return failure();
1501   return success();
1502 }
1503 
1504 void AtomicWriteOp::print(OpAsmPrinter &p) {
1505   p << " " << address() << " = " << value() << " ";
1506   if (auto mo = memory_order())
1507     p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
1508   if (hintAttr())
1509     printSynchronizationHint(p, *this, hintAttr());
1510   p << ": " << address().getType() << ", " << value().getType();
1511 }
1512 
1513 /// Verifier for AtomicWriteOp
1514 LogicalResult AtomicWriteOp::verify() {
1515   if (auto mo = memory_order()) {
1516     if (*mo == ClauseMemoryOrderKind::acq_rel ||
1517         *mo == ClauseMemoryOrderKind::acquire) {
1518       return emitError(
1519           "memory-order must not be acq_rel or acquire for atomic writes");
1520     }
1521   }
1522   return verifySynchronizationHint(*this, hint());
1523 }
1524 
1525 //===----------------------------------------------------------------------===//
1526 // AtomicUpdateOp
1527 //===----------------------------------------------------------------------===//
1528 
1529 /// Parser for AtomicUpdateOp
1530 ///
1531 /// operation ::= `omp.atomic.update` atomic-clause-list ssa-id-and-type region
1532 ParseResult AtomicUpdateOp::parse(OpAsmParser &parser, OperationState &result) {
1533   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1534   SmallVector<int> segments;
1535   OpAsmParser::OperandType x, expr;
1536   Type xType;
1537 
1538   if (parseClauses(parser, result, clauses, segments) ||
1539       parser.parseOperand(x) || parser.parseColon() ||
1540       parser.parseType(xType) ||
1541       parser.resolveOperand(x, xType, result.operands) ||
1542       parser.parseRegion(*result.addRegion()))
1543     return failure();
1544   return success();
1545 }
1546 
1547 void AtomicUpdateOp::print(OpAsmPrinter &p) {
1548   p << " ";
1549   if (auto mo = memory_order())
1550     p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
1551   if (hintAttr())
1552     printSynchronizationHint(p, *this, hintAttr());
1553   p << x() << " : " << x().getType();
1554   p.printRegion(region());
1555 }
1556 
1557 /// Verifier for AtomicUpdateOp
1558 LogicalResult AtomicUpdateOp::verify() {
1559   if (auto mo = memory_order()) {
1560     if (*mo == ClauseMemoryOrderKind::acq_rel ||
1561         *mo == ClauseMemoryOrderKind::acquire) {
1562       return emitError(
1563           "memory-order must not be acq_rel or acquire for atomic updates");
1564     }
1565   }
1566 
1567   if (region().getNumArguments() != 1)
1568     return emitError("the region must accept exactly one argument");
1569 
1570   if (x().getType().cast<PointerLikeType>().getElementType() !=
1571       region().getArgument(0).getType()) {
1572     return emitError("the type of the operand must be a pointer type whose "
1573                      "element type is the same as that of the region argument");
1574   }
1575 
1576   YieldOp yieldOp = *region().getOps<YieldOp>().begin();
1577 
1578   if (yieldOp.results().size() != 1)
1579     return emitError("only updated value must be returned");
1580   if (yieldOp.results().front().getType() != region().getArgument(0).getType())
1581     return emitError("input and yielded value must have the same type");
1582   return success();
1583 }
1584 
1585 //===----------------------------------------------------------------------===//
1586 // AtomicCaptureOp
1587 //===----------------------------------------------------------------------===//
1588 
1589 ParseResult AtomicCaptureOp::parse(OpAsmParser &parser,
1590                                    OperationState &result) {
1591   SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1592   SmallVector<int> segments;
1593   if (parseClauses(parser, result, clauses, segments) ||
1594       parser.parseRegion(*result.addRegion()))
1595     return failure();
1596   return success();
1597 }
1598 
1599 void AtomicCaptureOp::print(OpAsmPrinter &p) {
1600   if (memory_order())
1601     p << "memory_order(" << memory_order() << ") ";
1602   if (hintAttr())
1603     printSynchronizationHint(p, *this, hintAttr());
1604   p.printRegion(region());
1605 }
1606 
1607 /// Verifier for AtomicCaptureOp
1608 LogicalResult AtomicCaptureOp::verify() {
1609   Block::OpListType &ops = region().front().getOperations();
1610   if (ops.size() != 3)
1611     return emitError()
1612            << "expected three operations in omp.atomic.capture region (one "
1613               "terminator, and two atomic ops)";
1614   auto &firstOp = ops.front();
1615   auto &secondOp = *ops.getNextNode(firstOp);
1616   auto firstReadStmt = dyn_cast<AtomicReadOp>(firstOp);
1617   auto firstUpdateStmt = dyn_cast<AtomicUpdateOp>(firstOp);
1618   auto secondReadStmt = dyn_cast<AtomicReadOp>(secondOp);
1619   auto secondUpdateStmt = dyn_cast<AtomicUpdateOp>(secondOp);
1620   auto secondWriteStmt = dyn_cast<AtomicWriteOp>(secondOp);
1621 
1622   if (!((firstUpdateStmt && secondReadStmt) ||
1623         (firstReadStmt && secondUpdateStmt) ||
1624         (firstReadStmt && secondWriteStmt)))
1625     return ops.front().emitError()
1626            << "invalid sequence of operations in the capture region";
1627   if (firstUpdateStmt && secondReadStmt &&
1628       firstUpdateStmt.x() != secondReadStmt.x())
1629     return firstUpdateStmt.emitError()
1630            << "updated variable in omp.atomic.update must be captured in "
1631               "second operation";
1632   if (firstReadStmt && secondUpdateStmt &&
1633       firstReadStmt.x() != secondUpdateStmt.x())
1634     return firstReadStmt.emitError()
1635            << "captured variable in omp.atomic.read must be updated in second "
1636               "operation";
1637   if (firstReadStmt && secondWriteStmt &&
1638       firstReadStmt.x() != secondWriteStmt.address())
1639     return firstReadStmt.emitError()
1640            << "captured variable in omp.atomic.read must be updated in "
1641               "second operation";
1642   return success();
1643 }
1644 
1645 #define GET_ATTRDEF_CLASSES
1646 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
1647 
1648 #define GET_OP_CLASSES
1649 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
1650