1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/OperationSupport.h"
19
20 #include "llvm/ADT/BitVector.h"
21 #include "llvm/ADT/SmallString.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/ADT/StringSwitch.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include <cstddef>
27
28 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
29 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
30 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
31 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
32
33 using namespace mlir;
34 using namespace mlir::omp;
35
36 namespace {
37 /// Model for pointer-like types that already provide a `getElementType` method.
38 template <typename T>
39 struct PointerLikeModel
40 : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
getElementType__anon1411d50b0111::PointerLikeModel41 Type getElementType(Type pointer) const {
42 return pointer.cast<T>().getElementType();
43 }
44 };
45 } // namespace
46
initialize()47 void OpenMPDialect::initialize() {
48 addOperations<
49 #define GET_OP_LIST
50 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
51 >();
52 addAttributes<
53 #define GET_ATTRDEF_LIST
54 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
55 >();
56
57 LLVM::LLVMPointerType::attachInterface<
58 PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
59 MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
60 }
61
62 //===----------------------------------------------------------------------===//
63 // Parser and printer for Allocate Clause
64 //===----------------------------------------------------------------------===//
65
66 /// Parse an allocate clause with allocators and a list of operands with types.
67 ///
68 /// allocate-operand-list :: = allocate-operand |
69 /// allocator-operand `,` allocate-operand-list
70 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
71 /// ssa-id-and-type ::= ssa-id `:` type
parseAllocateAndAllocator(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & operandsAllocate,SmallVectorImpl<Type> & typesAllocate,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & operandsAllocator,SmallVectorImpl<Type> & typesAllocator)72 static ParseResult parseAllocateAndAllocator(
73 OpAsmParser &parser,
74 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operandsAllocate,
75 SmallVectorImpl<Type> &typesAllocate,
76 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operandsAllocator,
77 SmallVectorImpl<Type> &typesAllocator) {
78
79 return parser.parseCommaSeparatedList([&]() {
80 OpAsmParser::UnresolvedOperand operand;
81 Type type;
82 if (parser.parseOperand(operand) || parser.parseColonType(type))
83 return failure();
84 operandsAllocator.push_back(operand);
85 typesAllocator.push_back(type);
86 if (parser.parseArrow())
87 return failure();
88 if (parser.parseOperand(operand) || parser.parseColonType(type))
89 return failure();
90
91 operandsAllocate.push_back(operand);
92 typesAllocate.push_back(type);
93 return success();
94 });
95 }
96
97 /// Print allocate clause
printAllocateAndAllocator(OpAsmPrinter & p,Operation * op,OperandRange varsAllocate,TypeRange typesAllocate,OperandRange varsAllocator,TypeRange typesAllocator)98 static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op,
99 OperandRange varsAllocate,
100 TypeRange typesAllocate,
101 OperandRange varsAllocator,
102 TypeRange typesAllocator) {
103 for (unsigned i = 0; i < varsAllocate.size(); ++i) {
104 std::string separator = i == varsAllocate.size() - 1 ? "" : ", ";
105 p << varsAllocator[i] << " : " << typesAllocator[i] << " -> ";
106 p << varsAllocate[i] << " : " << typesAllocate[i] << separator;
107 }
108 }
109
110 //===----------------------------------------------------------------------===//
111 // Parser and printer for a clause attribute (StringEnumAttr)
112 //===----------------------------------------------------------------------===//
113
114 template <typename ClauseAttr>
parseClauseAttr(AsmParser & parser,ClauseAttr & attr)115 static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
116 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
117 StringRef enumStr;
118 SMLoc loc = parser.getCurrentLocation();
119 if (parser.parseKeyword(&enumStr))
120 return failure();
121 if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
122 attr = ClauseAttr::get(parser.getContext(), *enumValue);
123 return success();
124 }
125 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
126 }
127
128 template <typename ClauseAttr>
printClauseAttr(OpAsmPrinter & p,Operation * op,ClauseAttr attr)129 void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
130 p << stringifyEnum(attr.getValue());
131 }
132
133 //===----------------------------------------------------------------------===//
134 // Parser and printer for Linear Clause
135 //===----------------------------------------------------------------------===//
136
137 /// linear ::= `linear` `(` linear-list `)`
138 /// linear-list := linear-val | linear-val linear-list
139 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
140 static ParseResult
parseLinearClause(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & vars,SmallVectorImpl<Type> & types,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & stepVars)141 parseLinearClause(OpAsmParser &parser,
142 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
143 SmallVectorImpl<Type> &types,
144 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &stepVars) {
145 return parser.parseCommaSeparatedList([&]() {
146 OpAsmParser::UnresolvedOperand var;
147 Type type;
148 OpAsmParser::UnresolvedOperand stepVar;
149 if (parser.parseOperand(var) || parser.parseEqual() ||
150 parser.parseOperand(stepVar) || parser.parseColonType(type))
151 return failure();
152
153 vars.push_back(var);
154 types.push_back(type);
155 stepVars.push_back(stepVar);
156 return success();
157 });
158 }
159
160 /// Print Linear Clause
printLinearClause(OpAsmPrinter & p,Operation * op,ValueRange linearVars,TypeRange linearVarTypes,ValueRange linearStepVars)161 static void printLinearClause(OpAsmPrinter &p, Operation *op,
162 ValueRange linearVars, TypeRange linearVarTypes,
163 ValueRange linearStepVars) {
164 size_t linearVarsSize = linearVars.size();
165 for (unsigned i = 0; i < linearVarsSize; ++i) {
166 std::string separator = i == linearVarsSize - 1 ? "" : ", ";
167 p << linearVars[i];
168 if (linearStepVars.size() > i)
169 p << " = " << linearStepVars[i];
170 p << " : " << linearVars[i].getType() << separator;
171 }
172 }
173
174 //===----------------------------------------------------------------------===//
175 // Parser, printer and verifier for Schedule Clause
176 //===----------------------------------------------------------------------===//
177
178 static ParseResult
verifyScheduleModifiers(OpAsmParser & parser,SmallVectorImpl<SmallString<12>> & modifiers)179 verifyScheduleModifiers(OpAsmParser &parser,
180 SmallVectorImpl<SmallString<12>> &modifiers) {
181 if (modifiers.size() > 2)
182 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
183 for (const auto &mod : modifiers) {
184 // Translate the string. If it has no value, then it was not a valid
185 // modifier!
186 auto symbol = symbolizeScheduleModifier(mod);
187 if (!symbol)
188 return parser.emitError(parser.getNameLoc())
189 << " unknown modifier type: " << mod;
190 }
191
192 // If we have one modifier that is "simd", then stick a "none" modiifer in
193 // index 0.
194 if (modifiers.size() == 1) {
195 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
196 modifiers.push_back(modifiers[0]);
197 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
198 }
199 } else if (modifiers.size() == 2) {
200 // If there are two modifier:
201 // First modifier should not be simd, second one should be simd
202 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
203 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
204 return parser.emitError(parser.getNameLoc())
205 << " incorrect modifier order";
206 }
207 return success();
208 }
209
210 /// schedule ::= `schedule` `(` sched-list `)`
211 /// sched-list ::= sched-val | sched-val sched-list |
212 /// sched-val `,` sched-modifier
213 /// sched-val ::= sched-with-chunk | sched-wo-chunk
214 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
215 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
216 /// sched-wo-chunk ::= `auto` | `runtime`
217 /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
218 /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
parseScheduleClause(OpAsmParser & parser,ClauseScheduleKindAttr & scheduleAttr,ScheduleModifierAttr & scheduleModifier,UnitAttr & simdModifier,Optional<OpAsmParser::UnresolvedOperand> & chunkSize,Type & chunkType)219 static ParseResult parseScheduleClause(
220 OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
221 ScheduleModifierAttr &scheduleModifier, UnitAttr &simdModifier,
222 Optional<OpAsmParser::UnresolvedOperand> &chunkSize, Type &chunkType) {
223 StringRef keyword;
224 if (parser.parseKeyword(&keyword))
225 return failure();
226 llvm::Optional<mlir::omp::ClauseScheduleKind> schedule =
227 symbolizeClauseScheduleKind(keyword);
228 if (!schedule)
229 return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
230
231 scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
232 switch (*schedule) {
233 case ClauseScheduleKind::Static:
234 case ClauseScheduleKind::Dynamic:
235 case ClauseScheduleKind::Guided:
236 if (succeeded(parser.parseOptionalEqual())) {
237 chunkSize = OpAsmParser::UnresolvedOperand{};
238 if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
239 return failure();
240 } else {
241 chunkSize = llvm::NoneType::None;
242 }
243 break;
244 case ClauseScheduleKind::Auto:
245 case ClauseScheduleKind::Runtime:
246 chunkSize = llvm::NoneType::None;
247 }
248
249 // If there is a comma, we have one or more modifiers..
250 SmallVector<SmallString<12>> modifiers;
251 while (succeeded(parser.parseOptionalComma())) {
252 StringRef mod;
253 if (parser.parseKeyword(&mod))
254 return failure();
255 modifiers.push_back(mod);
256 }
257
258 if (verifyScheduleModifiers(parser, modifiers))
259 return failure();
260
261 if (!modifiers.empty()) {
262 SMLoc loc = parser.getCurrentLocation();
263 if (Optional<ScheduleModifier> mod =
264 symbolizeScheduleModifier(modifiers[0])) {
265 scheduleModifier = ScheduleModifierAttr::get(parser.getContext(), *mod);
266 } else {
267 return parser.emitError(loc, "invalid schedule modifier");
268 }
269 // Only SIMD attribute is allowed here!
270 if (modifiers.size() > 1) {
271 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
272 simdModifier = UnitAttr::get(parser.getBuilder().getContext());
273 }
274 }
275
276 return success();
277 }
278
279 /// Print schedule clause
printScheduleClause(OpAsmPrinter & p,Operation * op,ClauseScheduleKindAttr schedAttr,ScheduleModifierAttr modifier,UnitAttr simd,Value scheduleChunkVar,Type scheduleChunkType)280 static void printScheduleClause(OpAsmPrinter &p, Operation *op,
281 ClauseScheduleKindAttr schedAttr,
282 ScheduleModifierAttr modifier, UnitAttr simd,
283 Value scheduleChunkVar,
284 Type scheduleChunkType) {
285 p << stringifyClauseScheduleKind(schedAttr.getValue());
286 if (scheduleChunkVar)
287 p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType();
288 if (modifier)
289 p << ", " << stringifyScheduleModifier(modifier.getValue());
290 if (simd)
291 p << ", simd";
292 }
293
294 //===----------------------------------------------------------------------===//
295 // Parser, printer and verifier for ReductionVarList
296 //===----------------------------------------------------------------------===//
297
298 /// reduction-entry-list ::= reduction-entry
299 /// | reduction-entry-list `,` reduction-entry
300 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
301 static ParseResult
parseReductionVarList(OpAsmParser & parser,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & operands,SmallVectorImpl<Type> & types,ArrayAttr & redcuctionSymbols)302 parseReductionVarList(OpAsmParser &parser,
303 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
304 SmallVectorImpl<Type> &types,
305 ArrayAttr &redcuctionSymbols) {
306 SmallVector<SymbolRefAttr> reductionVec;
307 if (failed(parser.parseCommaSeparatedList([&]() {
308 if (parser.parseAttribute(reductionVec.emplace_back()) ||
309 parser.parseArrow() ||
310 parser.parseOperand(operands.emplace_back()) ||
311 parser.parseColonType(types.emplace_back()))
312 return failure();
313 return success();
314 })))
315 return failure();
316 SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
317 redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions);
318 return success();
319 }
320
321 /// Print Reduction clause
printReductionVarList(OpAsmPrinter & p,Operation * op,OperandRange reductionVars,TypeRange reductionTypes,Optional<ArrayAttr> reductions)322 static void printReductionVarList(OpAsmPrinter &p, Operation *op,
323 OperandRange reductionVars,
324 TypeRange reductionTypes,
325 Optional<ArrayAttr> reductions) {
326 for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
327 if (i != 0)
328 p << ", ";
329 p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
330 << reductionVars[i].getType();
331 }
332 }
333
334 /// Verifies Reduction Clause
verifyReductionVarList(Operation * op,Optional<ArrayAttr> reductions,OperandRange reductionVars)335 static LogicalResult verifyReductionVarList(Operation *op,
336 Optional<ArrayAttr> reductions,
337 OperandRange reductionVars) {
338 if (!reductionVars.empty()) {
339 if (!reductions || reductions->size() != reductionVars.size())
340 return op->emitOpError()
341 << "expected as many reduction symbol references "
342 "as reduction variables";
343 } else {
344 if (reductions)
345 return op->emitOpError() << "unexpected reduction symbol references";
346 return success();
347 }
348
349 // TODO: The followings should be done in
350 // SymbolUserOpInterface::verifySymbolUses.
351 DenseSet<Value> accumulators;
352 for (auto args : llvm::zip(reductionVars, *reductions)) {
353 Value accum = std::get<0>(args);
354
355 if (!accumulators.insert(accum).second)
356 return op->emitOpError() << "accumulator variable used more than once";
357
358 Type varType = accum.getType().cast<PointerLikeType>();
359 auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
360 auto decl =
361 SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
362 if (!decl)
363 return op->emitOpError() << "expected symbol reference " << symbolRef
364 << " to point to a reduction declaration";
365
366 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
367 return op->emitOpError()
368 << "expected accumulator (" << varType
369 << ") to be the same type as reduction declaration ("
370 << decl.getAccumulatorType() << ")";
371 }
372
373 return success();
374 }
375
376 //===----------------------------------------------------------------------===//
377 // Parser, printer and verifier for Synchronization Hint (2.17.12)
378 //===----------------------------------------------------------------------===//
379
380 /// Parses a Synchronization Hint clause. The value of hint is an integer
381 /// which is a combination of different hints from `omp_sync_hint_t`.
382 ///
383 /// hint-clause = `hint` `(` hint-value `)`
parseSynchronizationHint(OpAsmParser & parser,IntegerAttr & hintAttr)384 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
385 IntegerAttr &hintAttr) {
386 StringRef hintKeyword;
387 int64_t hint = 0;
388 if (succeeded(parser.parseOptionalKeyword("none"))) {
389 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
390 return success();
391 }
392 auto parseKeyword = [&]() -> ParseResult {
393 if (failed(parser.parseKeyword(&hintKeyword)))
394 return failure();
395 if (hintKeyword == "uncontended")
396 hint |= 1;
397 else if (hintKeyword == "contended")
398 hint |= 2;
399 else if (hintKeyword == "nonspeculative")
400 hint |= 4;
401 else if (hintKeyword == "speculative")
402 hint |= 8;
403 else
404 return parser.emitError(parser.getCurrentLocation())
405 << hintKeyword << " is not a valid hint";
406 return success();
407 };
408 if (parser.parseCommaSeparatedList(parseKeyword))
409 return failure();
410 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
411 return success();
412 }
413
414 /// Prints a Synchronization Hint clause
printSynchronizationHint(OpAsmPrinter & p,Operation * op,IntegerAttr hintAttr)415 static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
416 IntegerAttr hintAttr) {
417 int64_t hint = hintAttr.getInt();
418
419 if (hint == 0) {
420 p << "none";
421 return;
422 }
423
424 // Helper function to get n-th bit from the right end of `value`
425 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
426
427 bool uncontended = bitn(hint, 0);
428 bool contended = bitn(hint, 1);
429 bool nonspeculative = bitn(hint, 2);
430 bool speculative = bitn(hint, 3);
431
432 SmallVector<StringRef> hints;
433 if (uncontended)
434 hints.push_back("uncontended");
435 if (contended)
436 hints.push_back("contended");
437 if (nonspeculative)
438 hints.push_back("nonspeculative");
439 if (speculative)
440 hints.push_back("speculative");
441
442 llvm::interleaveComma(hints, p);
443 }
444
445 /// Verifies a synchronization hint clause
verifySynchronizationHint(Operation * op,uint64_t hint)446 static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
447
448 // Helper function to get n-th bit from the right end of `value`
449 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
450
451 bool uncontended = bitn(hint, 0);
452 bool contended = bitn(hint, 1);
453 bool nonspeculative = bitn(hint, 2);
454 bool speculative = bitn(hint, 3);
455
456 if (uncontended && contended)
457 return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
458 "omp_sync_hint_contended cannot be combined";
459 if (nonspeculative && speculative)
460 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
461 "omp_sync_hint_speculative cannot be combined.";
462 return success();
463 }
464
465 //===----------------------------------------------------------------------===//
466 // ParallelOp
467 //===----------------------------------------------------------------------===//
468
build(OpBuilder & builder,OperationState & state,ArrayRef<NamedAttribute> attributes)469 void ParallelOp::build(OpBuilder &builder, OperationState &state,
470 ArrayRef<NamedAttribute> attributes) {
471 ParallelOp::build(
472 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
473 /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
474 /*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
475 /*proc_bind_val=*/nullptr);
476 state.addAttributes(attributes);
477 }
478
verify()479 LogicalResult ParallelOp::verify() {
480 if (allocate_vars().size() != allocators_vars().size())
481 return emitError(
482 "expected equal sizes for allocate and allocator variables");
483 return verifyReductionVarList(*this, reductions(), reduction_vars());
484 }
485
486 //===----------------------------------------------------------------------===//
487 // Verifier for SectionsOp
488 //===----------------------------------------------------------------------===//
489
verify()490 LogicalResult SectionsOp::verify() {
491 if (allocate_vars().size() != allocators_vars().size())
492 return emitError(
493 "expected equal sizes for allocate and allocator variables");
494
495 return verifyReductionVarList(*this, reductions(), reduction_vars());
496 }
497
verifyRegions()498 LogicalResult SectionsOp::verifyRegions() {
499 for (auto &inst : *region().begin()) {
500 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
501 return emitOpError()
502 << "expected omp.section op or terminator op inside region";
503 }
504 }
505
506 return success();
507 }
508
verify()509 LogicalResult SingleOp::verify() {
510 // Check for allocate clause restrictions
511 if (allocate_vars().size() != allocators_vars().size())
512 return emitError(
513 "expected equal sizes for allocate and allocator variables");
514
515 return success();
516 }
517
518 //===----------------------------------------------------------------------===//
// Parser and printer for loop control
520 //===----------------------------------------------------------------------===//
521
522 /// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds
523 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
524 /// steps := `step` `(`ssa-id-list`)`
525 ParseResult
parseLoopControl(OpAsmParser & parser,Region & region,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & lowerBound,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & upperBound,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & steps,SmallVectorImpl<Type> & loopVarTypes,UnitAttr & inclusive)526 parseLoopControl(OpAsmParser &parser, Region ®ion,
527 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &lowerBound,
528 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &upperBound,
529 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &steps,
530 SmallVectorImpl<Type> &loopVarTypes, UnitAttr &inclusive) {
531 // Parse an opening `(` followed by induction variables followed by `)`
532 SmallVector<OpAsmParser::Argument> ivs;
533 Type loopVarType;
534 if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
535 parser.parseColonType(loopVarType) ||
536 // Parse loop bounds.
537 parser.parseEqual() ||
538 parser.parseOperandList(lowerBound, ivs.size(),
539 OpAsmParser::Delimiter::Paren) ||
540 parser.parseKeyword("to") ||
541 parser.parseOperandList(upperBound, ivs.size(),
542 OpAsmParser::Delimiter::Paren))
543 return failure();
544
545 if (succeeded(parser.parseOptionalKeyword("inclusive")))
546 inclusive = UnitAttr::get(parser.getBuilder().getContext());
547
548 // Parse step values.
549 if (parser.parseKeyword("step") ||
550 parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
551 return failure();
552
553 // Now parse the body.
554 loopVarTypes = SmallVector<Type>(ivs.size(), loopVarType);
555 for (auto &iv : ivs)
556 iv.type = loopVarType;
557 return parser.parseRegion(region, ivs);
558 }
559
printLoopControl(OpAsmPrinter & p,Operation * op,Region & region,ValueRange lowerBound,ValueRange upperBound,ValueRange steps,TypeRange loopVarTypes,UnitAttr inclusive)560 void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion,
561 ValueRange lowerBound, ValueRange upperBound,
562 ValueRange steps, TypeRange loopVarTypes,
563 UnitAttr inclusive) {
564 auto args = region.front().getArguments();
565 p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound
566 << ") to (" << upperBound << ") ";
567 if (inclusive)
568 p << "inclusive ";
569 p << "step (" << steps << ") ";
570 p.printRegion(region, /*printEntryBlockArgs=*/false);
571 }
572
573 //===----------------------------------------------------------------------===//
574 // Verifier for Simd construct [2.9.3.1]
575 //===----------------------------------------------------------------------===//
576
verify()577 LogicalResult SimdLoopOp::verify() {
578 if (this->lowerBound().empty()) {
579 return emitOpError() << "empty lowerbound for simd loop operation";
580 }
581 return success();
582 }
583
584 //===----------------------------------------------------------------------===//
585 // ReductionOp
586 //===----------------------------------------------------------------------===//
587
parseAtomicReductionRegion(OpAsmParser & parser,Region & region)588 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
589 Region ®ion) {
590 if (parser.parseOptionalKeyword("atomic"))
591 return success();
592 return parser.parseRegion(region);
593 }
594
printAtomicReductionRegion(OpAsmPrinter & printer,ReductionDeclareOp op,Region & region)595 static void printAtomicReductionRegion(OpAsmPrinter &printer,
596 ReductionDeclareOp op, Region ®ion) {
597 if (region.empty())
598 return;
599 printer << "atomic ";
600 printer.printRegion(region);
601 }
602
verifyRegions()603 LogicalResult ReductionDeclareOp::verifyRegions() {
604 if (initializerRegion().empty())
605 return emitOpError() << "expects non-empty initializer region";
606 Block &initializerEntryBlock = initializerRegion().front();
607 if (initializerEntryBlock.getNumArguments() != 1 ||
608 initializerEntryBlock.getArgument(0).getType() != type()) {
609 return emitOpError() << "expects initializer region with one argument "
610 "of the reduction type";
611 }
612
613 for (YieldOp yieldOp : initializerRegion().getOps<YieldOp>()) {
614 if (yieldOp.results().size() != 1 ||
615 yieldOp.results().getTypes()[0] != type())
616 return emitOpError() << "expects initializer region to yield a value "
617 "of the reduction type";
618 }
619
620 if (reductionRegion().empty())
621 return emitOpError() << "expects non-empty reduction region";
622 Block &reductionEntryBlock = reductionRegion().front();
623 if (reductionEntryBlock.getNumArguments() != 2 ||
624 reductionEntryBlock.getArgumentTypes()[0] !=
625 reductionEntryBlock.getArgumentTypes()[1] ||
626 reductionEntryBlock.getArgumentTypes()[0] != type())
627 return emitOpError() << "expects reduction region with two arguments of "
628 "the reduction type";
629 for (YieldOp yieldOp : reductionRegion().getOps<YieldOp>()) {
630 if (yieldOp.results().size() != 1 ||
631 yieldOp.results().getTypes()[0] != type())
632 return emitOpError() << "expects reduction region to yield a value "
633 "of the reduction type";
634 }
635
636 if (atomicReductionRegion().empty())
637 return success();
638
639 Block &atomicReductionEntryBlock = atomicReductionRegion().front();
640 if (atomicReductionEntryBlock.getNumArguments() != 2 ||
641 atomicReductionEntryBlock.getArgumentTypes()[0] !=
642 atomicReductionEntryBlock.getArgumentTypes()[1])
643 return emitOpError() << "expects atomic reduction region with two "
644 "arguments of the same type";
645 auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
646 .dyn_cast<PointerLikeType>();
647 if (!ptrType || ptrType.getElementType() != type())
648 return emitOpError() << "expects atomic reduction region arguments to "
649 "be accumulators containing the reduction type";
650 return success();
651 }
652
verify()653 LogicalResult ReductionOp::verify() {
654 auto *op = (*this)->getParentWithTrait<ReductionClauseInterface::Trait>();
655 if (!op)
656 return emitOpError() << "must be used within an operation supporting "
657 "reduction clause interface";
658 while (op) {
659 for (const auto &var :
660 cast<ReductionClauseInterface>(op).getReductionVars())
661 if (var == accumulator())
662 return success();
663 op = op->getParentWithTrait<ReductionClauseInterface::Trait>();
664 }
665 return emitOpError() << "the accumulator is not used by the parent";
666 }
667
668 //===----------------------------------------------------------------------===//
669 // TaskOp
670 //===----------------------------------------------------------------------===//
verify()671 LogicalResult TaskOp::verify() {
672 return verifyReductionVarList(*this, in_reductions(), in_reduction_vars());
673 }
674
675 //===----------------------------------------------------------------------===//
676 // TaskGroupOp
677 //===----------------------------------------------------------------------===//
verify()678 LogicalResult TaskGroupOp::verify() {
679 return verifyReductionVarList(*this, task_reductions(),
680 task_reduction_vars());
681 }
682
683 //===----------------------------------------------------------------------===//
684 // TaskLoopOp
685 //===----------------------------------------------------------------------===//
getReductionVars()686 SmallVector<Value> TaskLoopOp::getReductionVars() {
687 SmallVector<Value> all_reduction_nvars(in_reduction_vars().begin(),
688 in_reduction_vars().end());
689 all_reduction_nvars.insert(all_reduction_nvars.end(),
690 reduction_vars().begin(), reduction_vars().end());
691 return all_reduction_nvars;
692 }
693
verify()694 LogicalResult TaskLoopOp::verify() {
695 if (allocate_vars().size() != allocators_vars().size())
696 return emitError(
697 "expected equal sizes for allocate and allocator variables");
698 if (failed(verifyReductionVarList(*this, reductions(), reduction_vars())) ||
699 failed(
700 verifyReductionVarList(*this, in_reductions(), in_reduction_vars())))
701 return failure();
702
703 if (reduction_vars().size() > 0 && nogroup())
704 return emitError("if a reduction clause is present on the taskloop "
705 "directive, the nogroup clause must not be specified");
706 for (auto var : reduction_vars()) {
707 if (llvm::is_contained(in_reduction_vars(), var))
708 return emitError("the same list item cannot appear in both a reduction "
709 "and an in_reduction clause");
710 }
711
712 if (grain_size() && num_tasks()) {
713 return emitError(
714 "the grainsize clause and num_tasks clause are mutually exclusive and "
715 "may not appear on the same taskloop directive");
716 }
717 return success();
718 }
719
720 //===----------------------------------------------------------------------===//
721 // WsLoopOp
722 //===----------------------------------------------------------------------===//
723
build(OpBuilder & builder,OperationState & state,ValueRange lowerBound,ValueRange upperBound,ValueRange step,ArrayRef<NamedAttribute> attributes)724 void WsLoopOp::build(OpBuilder &builder, OperationState &state,
725 ValueRange lowerBound, ValueRange upperBound,
726 ValueRange step, ArrayRef<NamedAttribute> attributes) {
727 build(builder, state, lowerBound, upperBound, step,
728 /*linear_vars=*/ValueRange(),
729 /*linear_step_vars=*/ValueRange(), /*reduction_vars=*/ValueRange(),
730 /*reductions=*/nullptr, /*schedule_val=*/nullptr,
731 /*schedule_chunk_var=*/nullptr, /*schedule_modifier=*/nullptr,
732 /*simd_modifier=*/false, /*nowait=*/false, /*ordered_val=*/nullptr,
733 /*order_val=*/nullptr, /*inclusive=*/false);
734 state.addAttributes(attributes);
735 }
736
verify()737 LogicalResult WsLoopOp::verify() {
738 return verifyReductionVarList(*this, reductions(), reduction_vars());
739 }
740
741 //===----------------------------------------------------------------------===//
742 // Verifier for critical construct (2.17.1)
743 //===----------------------------------------------------------------------===//
744
verify()745 LogicalResult CriticalDeclareOp::verify() {
746 return verifySynchronizationHint(*this, hint_val());
747 }
748
verifySymbolUses(SymbolTableCollection & symbolTable)749 LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
750 if (nameAttr()) {
751 SymbolRefAttr symbolRef = nameAttr();
752 auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
753 *this, symbolRef);
754 if (!decl) {
755 return emitOpError() << "expected symbol reference " << symbolRef
756 << " to point to a critical declaration";
757 }
758 }
759
760 return success();
761 }
762
763 //===----------------------------------------------------------------------===//
764 // Verifier for ordered construct
765 //===----------------------------------------------------------------------===//
766
verify()767 LogicalResult OrderedOp::verify() {
768 auto container = (*this)->getParentOfType<WsLoopOp>();
769 if (!container || !container.ordered_valAttr() ||
770 container.ordered_valAttr().getInt() == 0)
771 return emitOpError() << "ordered depend directive must be closely "
772 << "nested inside a worksharing-loop with ordered "
773 << "clause with parameter present";
774
775 if (container.ordered_valAttr().getInt() != (int64_t)*num_loops_val())
776 return emitOpError() << "number of variables in depend clause does not "
777 << "match number of iteration variables in the "
778 << "doacross loop";
779
780 return success();
781 }
782
verify()783 LogicalResult OrderedRegionOp::verify() {
784 // TODO: The code generation for ordered simd directive is not supported yet.
785 if (simd())
786 return failure();
787
788 if (auto container = (*this)->getParentOfType<WsLoopOp>()) {
789 if (!container.ordered_valAttr() ||
790 container.ordered_valAttr().getInt() != 0)
791 return emitOpError() << "ordered region must be closely nested inside "
792 << "a worksharing-loop region with an ordered "
793 << "clause without parameter present";
794 }
795
796 return success();
797 }
798
799 //===----------------------------------------------------------------------===//
800 // Verifier for AtomicReadOp
801 //===----------------------------------------------------------------------===//
802
verify()803 LogicalResult AtomicReadOp::verify() {
804 if (auto mo = memory_order_val()) {
805 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
806 *mo == ClauseMemoryOrderKind::Release) {
807 return emitError(
808 "memory-order must not be acq_rel or release for atomic reads");
809 }
810 }
811 if (x() == v())
812 return emitError(
813 "read and write must not be to the same location for atomic reads");
814 return verifySynchronizationHint(*this, hint_val());
815 }
816
817 //===----------------------------------------------------------------------===//
818 // Verifier for AtomicWriteOp
819 //===----------------------------------------------------------------------===//
820
verify()821 LogicalResult AtomicWriteOp::verify() {
822 if (auto mo = memory_order_val()) {
823 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
824 *mo == ClauseMemoryOrderKind::Acquire) {
825 return emitError(
826 "memory-order must not be acq_rel or acquire for atomic writes");
827 }
828 }
829 if (address().getType().cast<PointerLikeType>().getElementType() !=
830 value().getType())
831 return emitError("address must dereference to value type");
832 return verifySynchronizationHint(*this, hint_val());
833 }
834
835 //===----------------------------------------------------------------------===//
836 // Verifier for AtomicUpdateOp
837 //===----------------------------------------------------------------------===//
838
verify()839 LogicalResult AtomicUpdateOp::verify() {
840 if (auto mo = memory_order_val()) {
841 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
842 *mo == ClauseMemoryOrderKind::Acquire) {
843 return emitError(
844 "memory-order must not be acq_rel or acquire for atomic updates");
845 }
846 }
847
848 if (x().getType().cast<PointerLikeType>().getElementType() !=
849 region().getArgument(0).getType()) {
850 return emitError("the type of the operand must be a pointer type whose "
851 "element type is the same as that of the region argument");
852 }
853
854 return verifySynchronizationHint(*this, hint_val());
855 }
856
verifyRegions()857 LogicalResult AtomicUpdateOp::verifyRegions() {
858 if (region().getNumArguments() != 1)
859 return emitError("the region must accept exactly one argument");
860
861 if (region().front().getOperations().size() < 2)
862 return emitError() << "the update region must have at least two operations "
863 "(binop and terminator)";
864
865 YieldOp yieldOp = *region().getOps<YieldOp>().begin();
866
867 if (yieldOp.results().size() != 1)
868 return emitError("only updated value must be returned");
869 if (yieldOp.results().front().getType() != region().getArgument(0).getType())
870 return emitError("input and yielded value must have the same type");
871 return success();
872 }
873
874 //===----------------------------------------------------------------------===//
875 // Verifier for AtomicCaptureOp
876 //===----------------------------------------------------------------------===//
877
getFirstOp()878 Operation *AtomicCaptureOp::getFirstOp() {
879 return &getRegion().front().getOperations().front();
880 }
881
getSecondOp()882 Operation *AtomicCaptureOp::getSecondOp() {
883 auto &ops = getRegion().front().getOperations();
884 return ops.getNextNode(ops.front());
885 }
886
getAtomicReadOp()887 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
888 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
889 return op;
890 return dyn_cast<AtomicReadOp>(getSecondOp());
891 }
892
getAtomicWriteOp()893 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
894 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
895 return op;
896 return dyn_cast<AtomicWriteOp>(getSecondOp());
897 }
898
getAtomicUpdateOp()899 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
900 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
901 return op;
902 return dyn_cast<AtomicUpdateOp>(getSecondOp());
903 }
904
verify()905 LogicalResult AtomicCaptureOp::verify() {
906 return verifySynchronizationHint(*this, hint_val());
907 }
908
verifyRegions()909 LogicalResult AtomicCaptureOp::verifyRegions() {
910 Block::OpListType &ops = region().front().getOperations();
911 if (ops.size() != 3)
912 return emitError()
913 << "expected three operations in omp.atomic.capture region (one "
914 "terminator, and two atomic ops)";
915 auto &firstOp = ops.front();
916 auto &secondOp = *ops.getNextNode(firstOp);
917 auto firstReadStmt = dyn_cast<AtomicReadOp>(firstOp);
918 auto firstUpdateStmt = dyn_cast<AtomicUpdateOp>(firstOp);
919 auto secondReadStmt = dyn_cast<AtomicReadOp>(secondOp);
920 auto secondUpdateStmt = dyn_cast<AtomicUpdateOp>(secondOp);
921 auto secondWriteStmt = dyn_cast<AtomicWriteOp>(secondOp);
922
923 if (!((firstUpdateStmt && secondReadStmt) ||
924 (firstReadStmt && secondUpdateStmt) ||
925 (firstReadStmt && secondWriteStmt)))
926 return ops.front().emitError()
927 << "invalid sequence of operations in the capture region";
928 if (firstUpdateStmt && secondReadStmt &&
929 firstUpdateStmt.x() != secondReadStmt.x())
930 return firstUpdateStmt.emitError()
931 << "updated variable in omp.atomic.update must be captured in "
932 "second operation";
933 if (firstReadStmt && secondUpdateStmt &&
934 firstReadStmt.x() != secondUpdateStmt.x())
935 return firstReadStmt.emitError()
936 << "captured variable in omp.atomic.read must be updated in second "
937 "operation";
938 if (firstReadStmt && secondWriteStmt &&
939 firstReadStmt.x() != secondWriteStmt.address())
940 return firstReadStmt.emitError()
941 << "captured variable in omp.atomic.read must be updated in "
942 "second operation";
943
944 if (getFirstOp()->getAttr("hint_val") || getSecondOp()->getAttr("hint_val"))
945 return emitOpError(
946 "operations inside capture region must not have hint clause");
947
948 if (getFirstOp()->getAttr("memory_order_val") ||
949 getSecondOp()->getAttr("memory_order_val"))
950 return emitOpError(
951 "operations inside capture region must not have memory_order clause");
952 return success();
953 }
954
955 //===----------------------------------------------------------------------===//
956 // Verifier for CancelOp
957 //===----------------------------------------------------------------------===//
958
verify()959 LogicalResult CancelOp::verify() {
960 ClauseCancellationConstructType cct = cancellation_construct_type_val();
961 Operation *parentOp = (*this)->getParentOp();
962
963 if (!parentOp) {
964 return emitOpError() << "must be used within a region supporting "
965 "cancel directive";
966 }
967
968 if ((cct == ClauseCancellationConstructType::Parallel) &&
969 !isa<ParallelOp>(parentOp)) {
970 return emitOpError() << "cancel parallel must appear "
971 << "inside a parallel region";
972 }
973 if (cct == ClauseCancellationConstructType::Loop) {
974 if (!isa<WsLoopOp>(parentOp)) {
975 return emitOpError() << "cancel loop must appear "
976 << "inside a worksharing-loop region";
977 }
978 if (cast<WsLoopOp>(parentOp).nowaitAttr()) {
979 return emitError() << "A worksharing construct that is canceled "
980 << "must not have a nowait clause";
981 }
982 if (cast<WsLoopOp>(parentOp).ordered_valAttr()) {
983 return emitError() << "A worksharing construct that is canceled "
984 << "must not have an ordered clause";
985 }
986
987 } else if (cct == ClauseCancellationConstructType::Sections) {
988 if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
989 return emitOpError() << "cancel sections must appear "
990 << "inside a sections region";
991 }
992 if (isa_and_nonnull<SectionsOp>(parentOp->getParentOp()) &&
993 cast<SectionsOp>(parentOp->getParentOp()).nowaitAttr()) {
994 return emitError() << "A sections construct that is canceled "
995 << "must not have a nowait clause";
996 }
997 }
998 // TODO : Add more when we support taskgroup.
999 return success();
1000 }
1001 //===----------------------------------------------------------------------===//
// Verifier for CancellationPointOp
1003 //===----------------------------------------------------------------------===//
1004
verify()1005 LogicalResult CancellationPointOp::verify() {
1006 ClauseCancellationConstructType cct = cancellation_construct_type_val();
1007 Operation *parentOp = (*this)->getParentOp();
1008
1009 if (!parentOp) {
1010 return emitOpError() << "must be used within a region supporting "
1011 "cancellation point directive";
1012 }
1013
1014 if ((cct == ClauseCancellationConstructType::Parallel) &&
1015 !(isa<ParallelOp>(parentOp))) {
1016 return emitOpError() << "cancellation point parallel must appear "
1017 << "inside a parallel region";
1018 }
1019 if ((cct == ClauseCancellationConstructType::Loop) &&
1020 !isa<WsLoopOp>(parentOp)) {
1021 return emitOpError() << "cancellation point loop must appear "
1022 << "inside a worksharing-loop region";
1023 }
1024 if ((cct == ClauseCancellationConstructType::Sections) &&
1025 !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
1026 return emitOpError() << "cancellation point sections must appear "
1027 << "inside a sections region";
1028 }
1029 // TODO : Add more when we support taskgroup.
1030 return success();
1031 }
1032
1033 #define GET_ATTRDEF_CLASSES
1034 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
1035
1036 #define GET_OP_CLASSES
1037 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
1038