1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/OperationSupport.h"
19 
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/SmallString.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/ADT/StringSwitch.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include <cstddef>
27 
28 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
29 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
30 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
31 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
32 
33 using namespace mlir;
34 using namespace mlir::omp;
35 
36 namespace {
37 /// Model for pointer-like types that already provide a `getElementType` method.
38 template <typename T>
39 struct PointerLikeModel
40     : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
getElementType__anon1411d50b0111::PointerLikeModel41   Type getElementType(Type pointer) const {
42     return pointer.cast<T>().getElementType();
43   }
44 };
45 } // namespace
46 
initialize()47 void OpenMPDialect::initialize() {
48   addOperations<
49 #define GET_OP_LIST
50 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
51       >();
52   addAttributes<
53 #define GET_ATTRDEF_LIST
54 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
55       >();
56 
57   LLVM::LLVMPointerType::attachInterface<
58       PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
59   MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
60 }
61 
62 //===----------------------------------------------------------------------===//
63 // Parser and printer for Allocate Clause
64 //===----------------------------------------------------------------------===//
65 
66 /// Parse an allocate clause with allocators and a list of operands with types.
67 ///
68 /// allocate-operand-list :: = allocate-operand |
69 ///                            allocator-operand `,` allocate-operand-list
70 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
71 /// ssa-id-and-type ::= ssa-id `:` type
parseAllocateAndAllocator(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & operandsAllocate,SmallVectorImpl<Type> & typesAllocate,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & operandsAllocator,SmallVectorImpl<Type> & typesAllocator)72 static ParseResult parseAllocateAndAllocator(
73     OpAsmParser &parser,
74     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operandsAllocate,
75     SmallVectorImpl<Type> &typesAllocate,
76     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operandsAllocator,
77     SmallVectorImpl<Type> &typesAllocator) {
78 
79   return parser.parseCommaSeparatedList([&]() {
80     OpAsmParser::UnresolvedOperand operand;
81     Type type;
82     if (parser.parseOperand(operand) || parser.parseColonType(type))
83       return failure();
84     operandsAllocator.push_back(operand);
85     typesAllocator.push_back(type);
86     if (parser.parseArrow())
87       return failure();
88     if (parser.parseOperand(operand) || parser.parseColonType(type))
89       return failure();
90 
91     operandsAllocate.push_back(operand);
92     typesAllocate.push_back(type);
93     return success();
94   });
95 }
96 
97 /// Print allocate clause
printAllocateAndAllocator(OpAsmPrinter & p,Operation * op,OperandRange varsAllocate,TypeRange typesAllocate,OperandRange varsAllocator,TypeRange typesAllocator)98 static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op,
99                                       OperandRange varsAllocate,
100                                       TypeRange typesAllocate,
101                                       OperandRange varsAllocator,
102                                       TypeRange typesAllocator) {
103   for (unsigned i = 0; i < varsAllocate.size(); ++i) {
104     std::string separator = i == varsAllocate.size() - 1 ? "" : ", ";
105     p << varsAllocator[i] << " : " << typesAllocator[i] << " -> ";
106     p << varsAllocate[i] << " : " << typesAllocate[i] << separator;
107   }
108 }
109 
110 //===----------------------------------------------------------------------===//
111 // Parser and printer for a clause attribute (StringEnumAttr)
112 //===----------------------------------------------------------------------===//
113 
114 template <typename ClauseAttr>
parseClauseAttr(AsmParser & parser,ClauseAttr & attr)115 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
116   using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
117   StringRef enumStr;
118   SMLoc loc = parser.getCurrentLocation();
119   if (parser.parseKeyword(&enumStr))
120     return failure();
121   if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
122     attr = ClauseAttr::get(parser.getContext(), *enumValue);
123     return success();
124   }
125   return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
126 }
127 
128 template <typename ClauseAttr>
printClauseAttr(OpAsmPrinter & p,Operation * op,ClauseAttr attr)129 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
130   p << stringifyEnum(attr.getValue());
131 }
132 
133 //===----------------------------------------------------------------------===//
134 // Parser and printer for Linear Clause
135 //===----------------------------------------------------------------------===//
136 
137 /// linear ::= `linear` `(` linear-list `)`
138 /// linear-list := linear-val | linear-val linear-list
139 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
140 static ParseResult
parseLinearClause(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & vars,SmallVectorImpl<Type> & types,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & stepVars)141 parseLinearClause(OpAsmParser &parser,
142                   SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
143                   SmallVectorImpl<Type> &types,
144                   SmallVectorImpl<OpAsmParser::UnresolvedOperand> &stepVars) {
145   return parser.parseCommaSeparatedList([&]() {
146     OpAsmParser::UnresolvedOperand var;
147     Type type;
148     OpAsmParser::UnresolvedOperand stepVar;
149     if (parser.parseOperand(var) || parser.parseEqual() ||
150         parser.parseOperand(stepVar) || parser.parseColonType(type))
151       return failure();
152 
153     vars.push_back(var);
154     types.push_back(type);
155     stepVars.push_back(stepVar);
156     return success();
157   });
158 }
159 
160 /// Print Linear Clause
printLinearClause(OpAsmPrinter & p,Operation * op,ValueRange linearVars,TypeRange linearVarTypes,ValueRange linearStepVars)161 static void printLinearClause(OpAsmPrinter &p, Operation *op,
162                               ValueRange linearVars, TypeRange linearVarTypes,
163                               ValueRange linearStepVars) {
164   size_t linearVarsSize = linearVars.size();
165   for (unsigned i = 0; i < linearVarsSize; ++i) {
166     std::string separator = i == linearVarsSize - 1 ? "" : ", ";
167     p << linearVars[i];
168     if (linearStepVars.size() > i)
169       p << " = " << linearStepVars[i];
170     p << " : " << linearVars[i].getType() << separator;
171   }
172 }
173 
174 //===----------------------------------------------------------------------===//
175 // Parser, printer and verifier for Schedule Clause
176 //===----------------------------------------------------------------------===//
177 
178 static ParseResult
verifyScheduleModifiers(OpAsmParser & parser,SmallVectorImpl<SmallString<12>> & modifiers)179 verifyScheduleModifiers(OpAsmParser &parser,
180                         SmallVectorImpl<SmallString<12>> &modifiers) {
181   if (modifiers.size() > 2)
182     return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
183   for (const auto &mod : modifiers) {
184     // Translate the string. If it has no value, then it was not a valid
185     // modifier!
186     auto symbol = symbolizeScheduleModifier(mod);
187     if (!symbol)
188       return parser.emitError(parser.getNameLoc())
189              << " unknown modifier type: " << mod;
190   }
191 
192   // If we have one modifier that is "simd", then stick a "none" modiifer in
193   // index 0.
194   if (modifiers.size() == 1) {
195     if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
196       modifiers.push_back(modifiers[0]);
197       modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
198     }
199   } else if (modifiers.size() == 2) {
200     // If there are two modifier:
201     // First modifier should not be simd, second one should be simd
202     if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
203         symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
204       return parser.emitError(parser.getNameLoc())
205              << " incorrect modifier order";
206   }
207   return success();
208 }
209 
210 /// schedule ::= `schedule` `(` sched-list `)`
211 /// sched-list ::= sched-val | sched-val sched-list |
212 ///                sched-val `,` sched-modifier
213 /// sched-val ::= sched-with-chunk | sched-wo-chunk
214 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
215 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
216 /// sched-wo-chunk ::=  `auto` | `runtime`
217 /// sched-modifier ::=  sched-mod-val | sched-mod-val `,` sched-mod-val
218 /// sched-mod-val ::=  `monotonic` | `nonmonotonic` | `simd` | `none`
parseScheduleClause(OpAsmParser & parser,ClauseScheduleKindAttr & scheduleAttr,ScheduleModifierAttr & scheduleModifier,UnitAttr & simdModifier,Optional<OpAsmParser::UnresolvedOperand> & chunkSize,Type & chunkType)219 static ParseResult parseScheduleClause(
220     OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
221     ScheduleModifierAttr &scheduleModifier, UnitAttr &simdModifier,
222     Optional<OpAsmParser::UnresolvedOperand> &chunkSize, Type &chunkType) {
223   StringRef keyword;
224   if (parser.parseKeyword(&keyword))
225     return failure();
226   llvm::Optional<mlir::omp::ClauseScheduleKind> schedule =
227       symbolizeClauseScheduleKind(keyword);
228   if (!schedule)
229     return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
230 
231   scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
232   switch (*schedule) {
233   case ClauseScheduleKind::Static:
234   case ClauseScheduleKind::Dynamic:
235   case ClauseScheduleKind::Guided:
236     if (succeeded(parser.parseOptionalEqual())) {
237       chunkSize = OpAsmParser::UnresolvedOperand{};
238       if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
239         return failure();
240     } else {
241       chunkSize = llvm::NoneType::None;
242     }
243     break;
244   case ClauseScheduleKind::Auto:
245   case ClauseScheduleKind::Runtime:
246     chunkSize = llvm::NoneType::None;
247   }
248 
249   // If there is a comma, we have one or more modifiers..
250   SmallVector<SmallString<12>> modifiers;
251   while (succeeded(parser.parseOptionalComma())) {
252     StringRef mod;
253     if (parser.parseKeyword(&mod))
254       return failure();
255     modifiers.push_back(mod);
256   }
257 
258   if (verifyScheduleModifiers(parser, modifiers))
259     return failure();
260 
261   if (!modifiers.empty()) {
262     SMLoc loc = parser.getCurrentLocation();
263     if (Optional<ScheduleModifier> mod =
264             symbolizeScheduleModifier(modifiers[0])) {
265       scheduleModifier = ScheduleModifierAttr::get(parser.getContext(), *mod);
266     } else {
267       return parser.emitError(loc, "invalid schedule modifier");
268     }
269     // Only SIMD attribute is allowed here!
270     if (modifiers.size() > 1) {
271       assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
272       simdModifier = UnitAttr::get(parser.getBuilder().getContext());
273     }
274   }
275 
276   return success();
277 }
278 
279 /// Print schedule clause
printScheduleClause(OpAsmPrinter & p,Operation * op,ClauseScheduleKindAttr schedAttr,ScheduleModifierAttr modifier,UnitAttr simd,Value scheduleChunkVar,Type scheduleChunkType)280 static void printScheduleClause(OpAsmPrinter &p, Operation *op,
281                                 ClauseScheduleKindAttr schedAttr,
282                                 ScheduleModifierAttr modifier, UnitAttr simd,
283                                 Value scheduleChunkVar,
284                                 Type scheduleChunkType) {
285   p << stringifyClauseScheduleKind(schedAttr.getValue());
286   if (scheduleChunkVar)
287     p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType();
288   if (modifier)
289     p << ", " << stringifyScheduleModifier(modifier.getValue());
290   if (simd)
291     p << ", simd";
292 }
293 
294 //===----------------------------------------------------------------------===//
295 // Parser, printer and verifier for ReductionVarList
296 //===----------------------------------------------------------------------===//
297 
298 /// reduction-entry-list ::= reduction-entry
299 ///                        | reduction-entry-list `,` reduction-entry
300 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
301 static ParseResult
parseReductionVarList(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & operands,SmallVectorImpl<Type> & types,ArrayAttr & redcuctionSymbols)302 parseReductionVarList(OpAsmParser &parser,
303                       SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
304                       SmallVectorImpl<Type> &types,
305                       ArrayAttr &redcuctionSymbols) {
306   SmallVector<SymbolRefAttr> reductionVec;
307   if (failed(parser.parseCommaSeparatedList([&]() {
308         if (parser.parseAttribute(reductionVec.emplace_back()) ||
309             parser.parseArrow() ||
310             parser.parseOperand(operands.emplace_back()) ||
311             parser.parseColonType(types.emplace_back()))
312           return failure();
313         return success();
314       })))
315     return failure();
316   SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
317   redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions);
318   return success();
319 }
320 
321 /// Print Reduction clause
printReductionVarList(OpAsmPrinter & p,Operation * op,OperandRange reductionVars,TypeRange reductionTypes,Optional<ArrayAttr> reductions)322 static void printReductionVarList(OpAsmPrinter &p, Operation *op,
323                                   OperandRange reductionVars,
324                                   TypeRange reductionTypes,
325                                   Optional<ArrayAttr> reductions) {
326   for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
327     if (i != 0)
328       p << ", ";
329     p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
330       << reductionVars[i].getType();
331   }
332 }
333 
334 /// Verifies Reduction Clause
verifyReductionVarList(Operation * op,Optional<ArrayAttr> reductions,OperandRange reductionVars)335 static LogicalResult verifyReductionVarList(Operation *op,
336                                             Optional<ArrayAttr> reductions,
337                                             OperandRange reductionVars) {
338   if (!reductionVars.empty()) {
339     if (!reductions || reductions->size() != reductionVars.size())
340       return op->emitOpError()
341              << "expected as many reduction symbol references "
342                 "as reduction variables";
343   } else {
344     if (reductions)
345       return op->emitOpError() << "unexpected reduction symbol references";
346     return success();
347   }
348 
349   // TODO: The followings should be done in
350   // SymbolUserOpInterface::verifySymbolUses.
351   DenseSet<Value> accumulators;
352   for (auto args : llvm::zip(reductionVars, *reductions)) {
353     Value accum = std::get<0>(args);
354 
355     if (!accumulators.insert(accum).second)
356       return op->emitOpError() << "accumulator variable used more than once";
357 
358     Type varType = accum.getType().cast<PointerLikeType>();
359     auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
360     auto decl =
361         SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
362     if (!decl)
363       return op->emitOpError() << "expected symbol reference " << symbolRef
364                                << " to point to a reduction declaration";
365 
366     if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
367       return op->emitOpError()
368              << "expected accumulator (" << varType
369              << ") to be the same type as reduction declaration ("
370              << decl.getAccumulatorType() << ")";
371   }
372 
373   return success();
374 }
375 
376 //===----------------------------------------------------------------------===//
377 // Parser, printer and verifier for Synchronization Hint (2.17.12)
378 //===----------------------------------------------------------------------===//
379 
380 /// Parses a Synchronization Hint clause. The value of hint is an integer
381 /// which is a combination of different hints from `omp_sync_hint_t`.
382 ///
383 /// hint-clause = `hint` `(` hint-value `)`
parseSynchronizationHint(OpAsmParser & parser,IntegerAttr & hintAttr)384 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
385                                             IntegerAttr &hintAttr) {
386   StringRef hintKeyword;
387   int64_t hint = 0;
388   if (succeeded(parser.parseOptionalKeyword("none"))) {
389     hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
390     return success();
391   }
392   auto parseKeyword = [&]() -> ParseResult {
393     if (failed(parser.parseKeyword(&hintKeyword)))
394       return failure();
395     if (hintKeyword == "uncontended")
396       hint |= 1;
397     else if (hintKeyword == "contended")
398       hint |= 2;
399     else if (hintKeyword == "nonspeculative")
400       hint |= 4;
401     else if (hintKeyword == "speculative")
402       hint |= 8;
403     else
404       return parser.emitError(parser.getCurrentLocation())
405              << hintKeyword << " is not a valid hint";
406     return success();
407   };
408   if (parser.parseCommaSeparatedList(parseKeyword))
409     return failure();
410   hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
411   return success();
412 }
413 
414 /// Prints a Synchronization Hint clause
printSynchronizationHint(OpAsmPrinter & p,Operation * op,IntegerAttr hintAttr)415 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
416                                      IntegerAttr hintAttr) {
417   int64_t hint = hintAttr.getInt();
418 
419   if (hint == 0) {
420     p << "none";
421     return;
422   }
423 
424   // Helper function to get n-th bit from the right end of `value`
425   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
426 
427   bool uncontended = bitn(hint, 0);
428   bool contended = bitn(hint, 1);
429   bool nonspeculative = bitn(hint, 2);
430   bool speculative = bitn(hint, 3);
431 
432   SmallVector<StringRef> hints;
433   if (uncontended)
434     hints.push_back("uncontended");
435   if (contended)
436     hints.push_back("contended");
437   if (nonspeculative)
438     hints.push_back("nonspeculative");
439   if (speculative)
440     hints.push_back("speculative");
441 
442   llvm::interleaveComma(hints, p);
443 }
444 
445 /// Verifies a synchronization hint clause
verifySynchronizationHint(Operation * op,uint64_t hint)446 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
447 
448   // Helper function to get n-th bit from the right end of `value`
449   auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
450 
451   bool uncontended = bitn(hint, 0);
452   bool contended = bitn(hint, 1);
453   bool nonspeculative = bitn(hint, 2);
454   bool speculative = bitn(hint, 3);
455 
456   if (uncontended && contended)
457     return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
458                                 "omp_sync_hint_contended cannot be combined";
459   if (nonspeculative && speculative)
460     return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
461                                 "omp_sync_hint_speculative cannot be combined.";
462   return success();
463 }
464 
465 //===----------------------------------------------------------------------===//
466 // ParallelOp
467 //===----------------------------------------------------------------------===//
468 
build(OpBuilder & builder,OperationState & state,ArrayRef<NamedAttribute> attributes)469 void ParallelOp::build(OpBuilder &builder, OperationState &state,
470                        ArrayRef<NamedAttribute> attributes) {
471   ParallelOp::build(
472       builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
473       /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
474       /*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
475       /*proc_bind_val=*/nullptr);
476   state.addAttributes(attributes);
477 }
478 
verify()479 LogicalResult ParallelOp::verify() {
480   if (allocate_vars().size() != allocators_vars().size())
481     return emitError(
482         "expected equal sizes for allocate and allocator variables");
483   return verifyReductionVarList(*this, reductions(), reduction_vars());
484 }
485 
486 //===----------------------------------------------------------------------===//
487 // Verifier for SectionsOp
488 //===----------------------------------------------------------------------===//
489 
verify()490 LogicalResult SectionsOp::verify() {
491   if (allocate_vars().size() != allocators_vars().size())
492     return emitError(
493         "expected equal sizes for allocate and allocator variables");
494 
495   return verifyReductionVarList(*this, reductions(), reduction_vars());
496 }
497 
verifyRegions()498 LogicalResult SectionsOp::verifyRegions() {
499   for (auto &inst : *region().begin()) {
500     if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
501       return emitOpError()
502              << "expected omp.section op or terminator op inside region";
503     }
504   }
505 
506   return success();
507 }
508 
verify()509 LogicalResult SingleOp::verify() {
510   // Check for allocate clause restrictions
511   if (allocate_vars().size() != allocators_vars().size())
512     return emitError(
513         "expected equal sizes for allocate and allocator variables");
514 
515   return success();
516 }
517 
518 //===----------------------------------------------------------------------===//
519 // WsLoopOp
520 //===----------------------------------------------------------------------===//
521 
522 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
523 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
524 /// steps := `step` `(`ssa-id-list`)`
525 ParseResult
parseLoopControl(OpAsmParser & parser,Region & region,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & lowerBound,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & upperBound,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & steps,SmallVectorImpl<Type> & loopVarTypes,UnitAttr & inclusive)526 parseLoopControl(OpAsmParser &parser, Region &region,
527                  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &lowerBound,
528                  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &upperBound,
529                  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &steps,
530                  SmallVectorImpl<Type> &loopVarTypes, UnitAttr &inclusive) {
531   // Parse an opening `(` followed by induction variables followed by `)`
532   SmallVector<OpAsmParser::Argument> ivs;
533   Type loopVarType;
534   if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
535       parser.parseColonType(loopVarType) ||
536       // Parse loop bounds.
537       parser.parseEqual() ||
538       parser.parseOperandList(lowerBound, ivs.size(),
539                               OpAsmParser::Delimiter::Paren) ||
540       parser.parseKeyword("to") ||
541       parser.parseOperandList(upperBound, ivs.size(),
542                               OpAsmParser::Delimiter::Paren))
543     return failure();
544 
545   if (succeeded(parser.parseOptionalKeyword("inclusive")))
546     inclusive = UnitAttr::get(parser.getBuilder().getContext());
547 
548   // Parse step values.
549   if (parser.parseKeyword("step") ||
550       parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
551     return failure();
552 
553   // Now parse the body.
554   loopVarTypes = SmallVector<Type>(ivs.size(), loopVarType);
555   for (auto &iv : ivs)
556     iv.type = loopVarType;
557   return parser.parseRegion(region, ivs);
558 }
559 
printLoopControl(OpAsmPrinter & p,Operation * op,Region & region,ValueRange lowerBound,ValueRange upperBound,ValueRange steps,TypeRange loopVarTypes,UnitAttr inclusive)560 void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region,
561                       ValueRange lowerBound, ValueRange upperBound,
562                       ValueRange steps, TypeRange loopVarTypes,
563                       UnitAttr inclusive) {
564   auto args = region.front().getArguments();
565   p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound
566     << ") to (" << upperBound << ") ";
567   if (inclusive)
568     p << "inclusive ";
569   p << "step (" << steps << ") ";
570   p.printRegion(region, /*printEntryBlockArgs=*/false);
571 }
572 
573 //===----------------------------------------------------------------------===//
574 // Verifier for Simd construct [2.9.3.1]
575 //===----------------------------------------------------------------------===//
576 
verify()577 LogicalResult SimdLoopOp::verify() {
578   if (this->lowerBound().empty()) {
579     return emitOpError() << "empty lowerbound for simd loop operation";
580   }
581   return success();
582 }
583 
584 //===----------------------------------------------------------------------===//
585 // ReductionOp
586 //===----------------------------------------------------------------------===//
587 
parseAtomicReductionRegion(OpAsmParser & parser,Region & region)588 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
589                                               Region &region) {
590   if (parser.parseOptionalKeyword("atomic"))
591     return success();
592   return parser.parseRegion(region);
593 }
594 
printAtomicReductionRegion(OpAsmPrinter & printer,ReductionDeclareOp op,Region & region)595 static void printAtomicReductionRegion(OpAsmPrinter &printer,
596                                        ReductionDeclareOp op, Region &region) {
597   if (region.empty())
598     return;
599   printer << "atomic ";
600   printer.printRegion(region);
601 }
602 
verifyRegions()603 LogicalResult ReductionDeclareOp::verifyRegions() {
604   if (initializerRegion().empty())
605     return emitOpError() << "expects non-empty initializer region";
606   Block &initializerEntryBlock = initializerRegion().front();
607   if (initializerEntryBlock.getNumArguments() != 1 ||
608       initializerEntryBlock.getArgument(0).getType() != type()) {
609     return emitOpError() << "expects initializer region with one argument "
610                             "of the reduction type";
611   }
612 
613   for (YieldOp yieldOp : initializerRegion().getOps<YieldOp>()) {
614     if (yieldOp.results().size() != 1 ||
615         yieldOp.results().getTypes()[0] != type())
616       return emitOpError() << "expects initializer region to yield a value "
617                               "of the reduction type";
618   }
619 
620   if (reductionRegion().empty())
621     return emitOpError() << "expects non-empty reduction region";
622   Block &reductionEntryBlock = reductionRegion().front();
623   if (reductionEntryBlock.getNumArguments() != 2 ||
624       reductionEntryBlock.getArgumentTypes()[0] !=
625           reductionEntryBlock.getArgumentTypes()[1] ||
626       reductionEntryBlock.getArgumentTypes()[0] != type())
627     return emitOpError() << "expects reduction region with two arguments of "
628                             "the reduction type";
629   for (YieldOp yieldOp : reductionRegion().getOps<YieldOp>()) {
630     if (yieldOp.results().size() != 1 ||
631         yieldOp.results().getTypes()[0] != type())
632       return emitOpError() << "expects reduction region to yield a value "
633                               "of the reduction type";
634   }
635 
636   if (atomicReductionRegion().empty())
637     return success();
638 
639   Block &atomicReductionEntryBlock = atomicReductionRegion().front();
640   if (atomicReductionEntryBlock.getNumArguments() != 2 ||
641       atomicReductionEntryBlock.getArgumentTypes()[0] !=
642           atomicReductionEntryBlock.getArgumentTypes()[1])
643     return emitOpError() << "expects atomic reduction region with two "
644                             "arguments of the same type";
645   auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
646                      .dyn_cast<PointerLikeType>();
647   if (!ptrType || ptrType.getElementType() != type())
648     return emitOpError() << "expects atomic reduction region arguments to "
649                             "be accumulators containing the reduction type";
650   return success();
651 }
652 
verify()653 LogicalResult ReductionOp::verify() {
654   auto *op = (*this)->getParentWithTrait<ReductionClauseInterface::Trait>();
655   if (!op)
656     return emitOpError() << "must be used within an operation supporting "
657                             "reduction clause interface";
658   while (op) {
659     for (const auto &var :
660          cast<ReductionClauseInterface>(op).getReductionVars())
661       if (var == accumulator())
662         return success();
663     op = op->getParentWithTrait<ReductionClauseInterface::Trait>();
664   }
665   return emitOpError() << "the accumulator is not used by the parent";
666 }
667 
668 //===----------------------------------------------------------------------===//
669 // TaskOp
670 //===----------------------------------------------------------------------===//
verify()671 LogicalResult TaskOp::verify() {
672   return verifyReductionVarList(*this, in_reductions(), in_reduction_vars());
673 }
674 
675 //===----------------------------------------------------------------------===//
676 // TaskGroupOp
677 //===----------------------------------------------------------------------===//
verify()678 LogicalResult TaskGroupOp::verify() {
679   return verifyReductionVarList(*this, task_reductions(),
680                                 task_reduction_vars());
681 }
682 
683 //===----------------------------------------------------------------------===//
684 // TaskLoopOp
685 //===----------------------------------------------------------------------===//
getReductionVars()686 SmallVector<Value> TaskLoopOp::getReductionVars() {
687   SmallVector<Value> all_reduction_nvars(in_reduction_vars().begin(),
688                                          in_reduction_vars().end());
689   all_reduction_nvars.insert(all_reduction_nvars.end(),
690                              reduction_vars().begin(), reduction_vars().end());
691   return all_reduction_nvars;
692 }
693 
verify()694 LogicalResult TaskLoopOp::verify() {
695   if (allocate_vars().size() != allocators_vars().size())
696     return emitError(
697         "expected equal sizes for allocate and allocator variables");
698   if (failed(verifyReductionVarList(*this, reductions(), reduction_vars())) ||
699       failed(
700           verifyReductionVarList(*this, in_reductions(), in_reduction_vars())))
701     return failure();
702 
703   if (reduction_vars().size() > 0 && nogroup())
704     return emitError("if a reduction clause is present on the taskloop "
705                      "directive, the nogroup clause must not be specified");
706   for (auto var : reduction_vars()) {
707     if (llvm::is_contained(in_reduction_vars(), var))
708       return emitError("the same list item cannot appear in both a reduction "
709                        "and an in_reduction clause");
710   }
711 
712   if (grain_size() && num_tasks()) {
713     return emitError(
714         "the grainsize clause and num_tasks clause are mutually exclusive and "
715         "may not appear on the same taskloop directive");
716   }
717   return success();
718 }
719 
720 //===----------------------------------------------------------------------===//
721 // WsLoopOp
722 //===----------------------------------------------------------------------===//
723 
/// Convenience builder for a worksharing loop that only specifies the loop
/// bounds and step. Every optional clause is left unset: no linear or
/// reduction variables, no schedule, and no order/ordered values. The
/// `nowait`, `simd_modifier` and `inclusive` flags are false. The supplied
/// `attributes` are appended to the operation state after the generated
/// builder has run.
void WsLoopOp::build(OpBuilder &builder, OperationState &state,
                     ValueRange lowerBound, ValueRange upperBound,
                     ValueRange step, ArrayRef<NamedAttribute> attributes) {
  // Delegate to the ODS-generated builder; the positional argument order
  // below must match that builder's signature exactly.
  build(builder, state, lowerBound, upperBound, step,
        /*linear_vars=*/ValueRange(),
        /*linear_step_vars=*/ValueRange(), /*reduction_vars=*/ValueRange(),
        /*reductions=*/nullptr, /*schedule_val=*/nullptr,
        /*schedule_chunk_var=*/nullptr, /*schedule_modifier=*/nullptr,
        /*simd_modifier=*/false, /*nowait=*/false, /*ordered_val=*/nullptr,
        /*order_val=*/nullptr, /*inclusive=*/false);
  state.addAttributes(attributes);
}
736 
verify()737 LogicalResult WsLoopOp::verify() {
738   return verifyReductionVarList(*this, reductions(), reduction_vars());
739 }
740 
741 //===----------------------------------------------------------------------===//
742 // Verifier for critical construct (2.17.1)
743 //===----------------------------------------------------------------------===//
744 
verify()745 LogicalResult CriticalDeclareOp::verify() {
746   return verifySynchronizationHint(*this, hint_val());
747 }
748 
verifySymbolUses(SymbolTableCollection & symbolTable)749 LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
750   if (nameAttr()) {
751     SymbolRefAttr symbolRef = nameAttr();
752     auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
753         *this, symbolRef);
754     if (!decl) {
755       return emitOpError() << "expected symbol reference " << symbolRef
756                            << " to point to a critical declaration";
757     }
758   }
759 
760   return success();
761 }
762 
763 //===----------------------------------------------------------------------===//
764 // Verifier for ordered construct
765 //===----------------------------------------------------------------------===//
766 
verify()767 LogicalResult OrderedOp::verify() {
768   auto container = (*this)->getParentOfType<WsLoopOp>();
769   if (!container || !container.ordered_valAttr() ||
770       container.ordered_valAttr().getInt() == 0)
771     return emitOpError() << "ordered depend directive must be closely "
772                          << "nested inside a worksharing-loop with ordered "
773                          << "clause with parameter present";
774 
775   if (container.ordered_valAttr().getInt() != (int64_t)*num_loops_val())
776     return emitOpError() << "number of variables in depend clause does not "
777                          << "match number of iteration variables in the "
778                          << "doacross loop";
779 
780   return success();
781 }
782 
verify()783 LogicalResult OrderedRegionOp::verify() {
784   // TODO: The code generation for ordered simd directive is not supported yet.
785   if (simd())
786     return failure();
787 
788   if (auto container = (*this)->getParentOfType<WsLoopOp>()) {
789     if (!container.ordered_valAttr() ||
790         container.ordered_valAttr().getInt() != 0)
791       return emitOpError() << "ordered region must be closely nested inside "
792                            << "a worksharing-loop region with an ordered "
793                            << "clause without parameter present";
794   }
795 
796   return success();
797 }
798 
799 //===----------------------------------------------------------------------===//
800 // Verifier for AtomicReadOp
801 //===----------------------------------------------------------------------===//
802 
verify()803 LogicalResult AtomicReadOp::verify() {
804   if (auto mo = memory_order_val()) {
805     if (*mo == ClauseMemoryOrderKind::Acq_rel ||
806         *mo == ClauseMemoryOrderKind::Release) {
807       return emitError(
808           "memory-order must not be acq_rel or release for atomic reads");
809     }
810   }
811   if (x() == v())
812     return emitError(
813         "read and write must not be to the same location for atomic reads");
814   return verifySynchronizationHint(*this, hint_val());
815 }
816 
817 //===----------------------------------------------------------------------===//
818 // Verifier for AtomicWriteOp
819 //===----------------------------------------------------------------------===//
820 
verify()821 LogicalResult AtomicWriteOp::verify() {
822   if (auto mo = memory_order_val()) {
823     if (*mo == ClauseMemoryOrderKind::Acq_rel ||
824         *mo == ClauseMemoryOrderKind::Acquire) {
825       return emitError(
826           "memory-order must not be acq_rel or acquire for atomic writes");
827     }
828   }
829   if (address().getType().cast<PointerLikeType>().getElementType() !=
830       value().getType())
831     return emitError("address must dereference to value type");
832   return verifySynchronizationHint(*this, hint_val());
833 }
834 
835 //===----------------------------------------------------------------------===//
836 // Verifier for AtomicUpdateOp
837 //===----------------------------------------------------------------------===//
838 
verify()839 LogicalResult AtomicUpdateOp::verify() {
840   if (auto mo = memory_order_val()) {
841     if (*mo == ClauseMemoryOrderKind::Acq_rel ||
842         *mo == ClauseMemoryOrderKind::Acquire) {
843       return emitError(
844           "memory-order must not be acq_rel or acquire for atomic updates");
845     }
846   }
847 
848   if (x().getType().cast<PointerLikeType>().getElementType() !=
849       region().getArgument(0).getType()) {
850     return emitError("the type of the operand must be a pointer type whose "
851                      "element type is the same as that of the region argument");
852   }
853 
854   return verifySynchronizationHint(*this, hint_val());
855 }
856 
verifyRegions()857 LogicalResult AtomicUpdateOp::verifyRegions() {
858   if (region().getNumArguments() != 1)
859     return emitError("the region must accept exactly one argument");
860 
861   if (region().front().getOperations().size() < 2)
862     return emitError() << "the update region must have at least two operations "
863                           "(binop and terminator)";
864 
865   YieldOp yieldOp = *region().getOps<YieldOp>().begin();
866 
867   if (yieldOp.results().size() != 1)
868     return emitError("only updated value must be returned");
869   if (yieldOp.results().front().getType() != region().getArgument(0).getType())
870     return emitError("input and yielded value must have the same type");
871   return success();
872 }
873 
874 //===----------------------------------------------------------------------===//
875 // Verifier for AtomicCaptureOp
876 //===----------------------------------------------------------------------===//
877 
getFirstOp()878 Operation *AtomicCaptureOp::getFirstOp() {
879   return &getRegion().front().getOperations().front();
880 }
881 
getSecondOp()882 Operation *AtomicCaptureOp::getSecondOp() {
883   auto &ops = getRegion().front().getOperations();
884   return ops.getNextNode(ops.front());
885 }
886 
getAtomicReadOp()887 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
888   if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
889     return op;
890   return dyn_cast<AtomicReadOp>(getSecondOp());
891 }
892 
getAtomicWriteOp()893 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
894   if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
895     return op;
896   return dyn_cast<AtomicWriteOp>(getSecondOp());
897 }
898 
getAtomicUpdateOp()899 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
900   if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
901     return op;
902   return dyn_cast<AtomicUpdateOp>(getSecondOp());
903 }
904 
verify()905 LogicalResult AtomicCaptureOp::verify() {
906   return verifySynchronizationHint(*this, hint_val());
907 }
908 
verifyRegions()909 LogicalResult AtomicCaptureOp::verifyRegions() {
910   Block::OpListType &ops = region().front().getOperations();
911   if (ops.size() != 3)
912     return emitError()
913            << "expected three operations in omp.atomic.capture region (one "
914               "terminator, and two atomic ops)";
915   auto &firstOp = ops.front();
916   auto &secondOp = *ops.getNextNode(firstOp);
917   auto firstReadStmt = dyn_cast<AtomicReadOp>(firstOp);
918   auto firstUpdateStmt = dyn_cast<AtomicUpdateOp>(firstOp);
919   auto secondReadStmt = dyn_cast<AtomicReadOp>(secondOp);
920   auto secondUpdateStmt = dyn_cast<AtomicUpdateOp>(secondOp);
921   auto secondWriteStmt = dyn_cast<AtomicWriteOp>(secondOp);
922 
923   if (!((firstUpdateStmt && secondReadStmt) ||
924         (firstReadStmt && secondUpdateStmt) ||
925         (firstReadStmt && secondWriteStmt)))
926     return ops.front().emitError()
927            << "invalid sequence of operations in the capture region";
928   if (firstUpdateStmt && secondReadStmt &&
929       firstUpdateStmt.x() != secondReadStmt.x())
930     return firstUpdateStmt.emitError()
931            << "updated variable in omp.atomic.update must be captured in "
932               "second operation";
933   if (firstReadStmt && secondUpdateStmt &&
934       firstReadStmt.x() != secondUpdateStmt.x())
935     return firstReadStmt.emitError()
936            << "captured variable in omp.atomic.read must be updated in second "
937               "operation";
938   if (firstReadStmt && secondWriteStmt &&
939       firstReadStmt.x() != secondWriteStmt.address())
940     return firstReadStmt.emitError()
941            << "captured variable in omp.atomic.read must be updated in "
942               "second operation";
943 
944   if (getFirstOp()->getAttr("hint_val") || getSecondOp()->getAttr("hint_val"))
945     return emitOpError(
946         "operations inside capture region must not have hint clause");
947 
948   if (getFirstOp()->getAttr("memory_order_val") ||
949       getSecondOp()->getAttr("memory_order_val"))
950     return emitOpError(
951         "operations inside capture region must not have memory_order clause");
952   return success();
953 }
954 
955 //===----------------------------------------------------------------------===//
956 // Verifier for CancelOp
957 //===----------------------------------------------------------------------===//
958 
verify()959 LogicalResult CancelOp::verify() {
960   ClauseCancellationConstructType cct = cancellation_construct_type_val();
961   Operation *parentOp = (*this)->getParentOp();
962 
963   if (!parentOp) {
964     return emitOpError() << "must be used within a region supporting "
965                             "cancel directive";
966   }
967 
968   if ((cct == ClauseCancellationConstructType::Parallel) &&
969       !isa<ParallelOp>(parentOp)) {
970     return emitOpError() << "cancel parallel must appear "
971                          << "inside a parallel region";
972   }
973   if (cct == ClauseCancellationConstructType::Loop) {
974     if (!isa<WsLoopOp>(parentOp)) {
975       return emitOpError() << "cancel loop must appear "
976                            << "inside a worksharing-loop region";
977     }
978     if (cast<WsLoopOp>(parentOp).nowaitAttr()) {
979       return emitError() << "A worksharing construct that is canceled "
980                          << "must not have a nowait clause";
981     }
982     if (cast<WsLoopOp>(parentOp).ordered_valAttr()) {
983       return emitError() << "A worksharing construct that is canceled "
984                          << "must not have an ordered clause";
985     }
986 
987   } else if (cct == ClauseCancellationConstructType::Sections) {
988     if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
989       return emitOpError() << "cancel sections must appear "
990                            << "inside a sections region";
991     }
992     if (isa_and_nonnull<SectionsOp>(parentOp->getParentOp()) &&
993         cast<SectionsOp>(parentOp->getParentOp()).nowaitAttr()) {
994       return emitError() << "A sections construct that is canceled "
995                          << "must not have a nowait clause";
996     }
997   }
998   // TODO : Add more when we support taskgroup.
999   return success();
1000 }

//===----------------------------------------------------------------------===//
// Verifier for CancellationPointOp
//===----------------------------------------------------------------------===//
1004 
verify()1005 LogicalResult CancellationPointOp::verify() {
1006   ClauseCancellationConstructType cct = cancellation_construct_type_val();
1007   Operation *parentOp = (*this)->getParentOp();
1008 
1009   if (!parentOp) {
1010     return emitOpError() << "must be used within a region supporting "
1011                             "cancellation point directive";
1012   }
1013 
1014   if ((cct == ClauseCancellationConstructType::Parallel) &&
1015       !(isa<ParallelOp>(parentOp))) {
1016     return emitOpError() << "cancellation point parallel must appear "
1017                          << "inside a parallel region";
1018   }
1019   if ((cct == ClauseCancellationConstructType::Loop) &&
1020       !isa<WsLoopOp>(parentOp)) {
1021     return emitOpError() << "cancellation point loop must appear "
1022                          << "inside a worksharing-loop region";
1023   }
1024   if ((cct == ClauseCancellationConstructType::Sections) &&
1025       !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
1026     return emitOpError() << "cancellation point sections must appear "
1027                          << "inside a sections region";
1028   }
1029   // TODO : Add more when we support taskgroup.
1030   return success();
1031 }
1032 
1033 #define GET_ATTRDEF_CLASSES
1034 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
1035 
1036 #define GET_OP_CLASSES
1037 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
1038