1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/OperationSupport.h"
19 
20 #include "llvm/ADT/SmallString.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "llvm/ADT/StringSwitch.h"
24 #include <cstddef>
25 
26 #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
27 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
28 #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
29 
30 using namespace mlir;
31 using namespace mlir::omp;
32 
33 namespace {
34 /// Model for pointer-like types that already provide a `getElementType` method.
35 template <typename T>
36 struct PointerLikeModel
37     : public PointerLikeType::ExternalModel<PointerLikeModel<T>, T> {
38   Type getElementType(Type pointer) const {
39     return pointer.cast<T>().getElementType();
40   }
41 };
42 } // end namespace
43 
44 void OpenMPDialect::initialize() {
45   addOperations<
46 #define GET_OP_LIST
47 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
48       >();
49 
50   LLVM::LLVMPointerType::attachInterface<
51       PointerLikeModel<LLVM::LLVMPointerType>>(*getContext());
52   MemRefType::attachInterface<PointerLikeModel<MemRefType>>(*getContext());
53 }
54 
55 //===----------------------------------------------------------------------===//
56 // ParallelOp
57 //===----------------------------------------------------------------------===//
58 
59 void ParallelOp::build(OpBuilder &builder, OperationState &state,
60                        ArrayRef<NamedAttribute> attributes) {
61   ParallelOp::build(
62       builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
63       /*default_val=*/nullptr, /*private_vars=*/ValueRange(),
64       /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(),
65       /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(),
66       /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr);
67   state.addAttributes(attributes);
68 }
69 
70 /// Parse a list of operands with types.
71 ///
72 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
73 /// ssa-id-and-type-list ::= ssa-id-and-type |
74 ///                          ssa-id-and-type `,` ssa-id-and-type-list
75 /// ssa-id-and-type ::= ssa-id `:` type
76 static ParseResult
77 parseOperandAndTypeList(OpAsmParser &parser,
78                         SmallVectorImpl<OpAsmParser::OperandType> &operands,
79                         SmallVectorImpl<Type> &types) {
80   if (parser.parseLParen())
81     return failure();
82 
83   do {
84     OpAsmParser::OperandType operand;
85     Type type;
86     if (parser.parseOperand(operand) || parser.parseColonType(type))
87       return failure();
88     operands.push_back(operand);
89     types.push_back(type);
90   } while (succeeded(parser.parseOptionalComma()));
91 
92   if (parser.parseRParen())
93     return failure();
94 
95   return success();
96 }
97 
98 /// Parse an allocate clause with allocators and a list of operands with types.
99 ///
100 /// operand-and-type-list ::= `(` allocate-operand-list `)`
101 /// allocate-operand-list :: = allocate-operand |
102 ///                            allocator-operand `,` allocate-operand-list
103 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
104 /// ssa-id-and-type ::= ssa-id `:` type
105 static ParseResult parseAllocateAndAllocator(
106     OpAsmParser &parser,
107     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
108     SmallVectorImpl<Type> &typesAllocate,
109     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
110     SmallVectorImpl<Type> &typesAllocator) {
111   if (parser.parseLParen())
112     return failure();
113 
114   do {
115     OpAsmParser::OperandType operand;
116     Type type;
117 
118     if (parser.parseOperand(operand) || parser.parseColonType(type))
119       return failure();
120     operandsAllocator.push_back(operand);
121     typesAllocator.push_back(type);
122     if (parser.parseArrow())
123       return failure();
124     if (parser.parseOperand(operand) || parser.parseColonType(type))
125       return failure();
126 
127     operandsAllocate.push_back(operand);
128     typesAllocate.push_back(type);
129   } while (succeeded(parser.parseOptionalComma()));
130 
131   if (parser.parseRParen())
132     return failure();
133 
134   return success();
135 }
136 
137 static LogicalResult verifyParallelOp(ParallelOp op) {
138   if (op.allocate_vars().size() != op.allocators_vars().size())
139     return op.emitError(
140         "expected equal sizes for allocate and allocator variables");
141   return success();
142 }
143 
144 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
145   if (auto ifCond = op.if_expr_var())
146     p << " if(" << ifCond << " : " << ifCond.getType() << ")";
147 
148   if (auto threads = op.num_threads_var())
149     p << " num_threads(" << threads << " : " << threads.getType() << ")";
150 
151   // Print private, firstprivate, shared and copyin parameters
152   auto printDataVars = [&p](StringRef name, OperandRange vars) {
153     if (vars.size()) {
154       p << " " << name << "(";
155       for (unsigned i = 0; i < vars.size(); ++i) {
156         std::string separator = i == vars.size() - 1 ? ")" : ", ";
157         p << vars[i] << " : " << vars[i].getType() << separator;
158       }
159     }
160   };
161 
162   // Print allocator and allocate parameters
163   auto printAllocateAndAllocator = [&p](OperandRange varsAllocate,
164                                         OperandRange varsAllocator) {
165     if (varsAllocate.empty())
166       return;
167 
168     p << " allocate(";
169     for (unsigned i = 0; i < varsAllocate.size(); ++i) {
170       std::string separator = i == varsAllocate.size() - 1 ? ")" : ", ";
171       p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
172       p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
173     }
174   };
175 
176   printDataVars("private", op.private_vars());
177   printDataVars("firstprivate", op.firstprivate_vars());
178   printDataVars("shared", op.shared_vars());
179   printDataVars("copyin", op.copyin_vars());
180   printAllocateAndAllocator(op.allocate_vars(), op.allocators_vars());
181 
182   if (auto def = op.default_val())
183     p << " default(" << def->drop_front(3) << ")";
184 
185   if (auto bind = op.proc_bind_val())
186     p << " proc_bind(" << bind << ")";
187 
188   p.printRegion(op.getRegion());
189 }
190 
191 /// Emit an error if the same clause is present more than once on an operation.
192 static ParseResult allowedOnce(OpAsmParser &parser, StringRef clause,
193                                StringRef operation) {
194   return parser.emitError(parser.getNameLoc())
195          << " at most one " << clause << " clause can appear on the "
196          << operation << " operation";
197 }
198 
/// Parses a parallel operation.
///
/// operation ::= `omp.parallel` clause-list region
/// clause-list ::= empty | clause clause-list
/// clause ::= if | numThreads | private | firstprivate | shared | copyin |
///            allocate | default | procBind
/// if ::= `if` `(` ssa-id-and-type `)`
/// numThreads ::= `num_threads` `(` ssa-id-and-type `)`
/// private ::= `private` operand-and-type-list
/// firstprivate ::= `firstprivate` operand-and-type-list
/// shared ::= `shared` operand-and-type-list
/// copyin ::= `copyin` operand-and-type-list
/// allocate ::= `allocate` `(` allocate-entry (`,` allocate-entry)* `)`
/// allocate-entry ::= ssa-id-and-type `->` ssa-id-and-type
/// default ::= `default` `(` (`private` | `firstprivate` | `shared` |
///             `none`) `)`
/// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
/// ssa-id-and-type ::= ssa-id `:` type
///
/// Note that each clause can only appear once in the clause-list. An
/// unrecognized clause keyword is an error. The body region is parsed with no
/// entry-block arguments.
static ParseResult parseParallelOp(OpAsmParser &parser,
                                   OperationState &result) {
  std::pair<OpAsmParser::OperandType, Type> ifCond;
  std::pair<OpAsmParser::OperandType, Type> numThreads;
  SmallVector<OpAsmParser::OperandType, 4> privates;
  SmallVector<Type, 4> privateTypes;
  SmallVector<OpAsmParser::OperandType, 4> firstprivates;
  SmallVector<Type, 4> firstprivateTypes;
  SmallVector<OpAsmParser::OperandType, 4> shareds;
  SmallVector<Type, 4> sharedTypes;
  SmallVector<OpAsmParser::OperandType, 4> copyins;
  SmallVector<Type, 4> copyinTypes;
  SmallVector<OpAsmParser::OperandType, 4> allocates;
  SmallVector<Type, 4> allocateTypes;
  SmallVector<OpAsmParser::OperandType, 4> allocators;
  SmallVector<Type, 4> allocatorTypes;
  // Operand segment sizes, indexed by the *ClausePos / allocatorPos constants
  // below. This order must match the order in which operands are appended to
  // result.operands after the clause loop. A non-zero entry also doubles as
  // the "clause already seen" flag for duplicate detection.
  std::array<int, 8> segments{0, 0, 0, 0, 0, 0, 0, 0};
  StringRef keyword;
  // `default` and `proc_bind` become attributes rather than operands, so
  // their presence is tracked with flags instead of segment sizes.
  bool defaultVal = false;
  bool procBind = false;

  const int ifClausePos = 0;
  const int numThreadsClausePos = 1;
  const int privateClausePos = 2;
  const int firstprivateClausePos = 3;
  const int sharedClausePos = 4;
  const int copyinClausePos = 5;
  const int allocateClausePos = 6;
  const int allocatorPos = 7;
  const StringRef opName = result.name.getStringRef();

  // Consume clauses until the next token is not a keyword (e.g. the `{` of
  // the body region).
  while (succeeded(parser.parseOptionalKeyword(&keyword))) {
    if (keyword == "if") {
      // Fail if there was already another if condition.
      if (segments[ifClausePos])
        return allowedOnce(parser, "if", opName);
      if (parser.parseLParen() || parser.parseOperand(ifCond.first) ||
          parser.parseColonType(ifCond.second) || parser.parseRParen())
        return failure();
      segments[ifClausePos] = 1;
    } else if (keyword == "num_threads") {
      // Fail if there was already another num_threads clause.
      if (segments[numThreadsClausePos])
        return allowedOnce(parser, "num_threads", opName);
      if (parser.parseLParen() || parser.parseOperand(numThreads.first) ||
          parser.parseColonType(numThreads.second) || parser.parseRParen())
        return failure();
      segments[numThreadsClausePos] = 1;
    } else if (keyword == "private") {
      // Fail if there was already another private clause.
      if (segments[privateClausePos])
        return allowedOnce(parser, "private", opName);
      if (parseOperandAndTypeList(parser, privates, privateTypes))
        return failure();
      segments[privateClausePos] = privates.size();
    } else if (keyword == "firstprivate") {
      // Fail if there was already another firstprivate clause.
      if (segments[firstprivateClausePos])
        return allowedOnce(parser, "firstprivate", opName);
      if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
        return failure();
      segments[firstprivateClausePos] = firstprivates.size();
    } else if (keyword == "shared") {
      // Fail if there was already another shared clause.
      if (segments[sharedClausePos])
        return allowedOnce(parser, "shared", opName);
      if (parseOperandAndTypeList(parser, shareds, sharedTypes))
        return failure();
      segments[sharedClausePos] = shareds.size();
    } else if (keyword == "copyin") {
      // Fail if there was already another copyin clause.
      if (segments[copyinClausePos])
        return allowedOnce(parser, "copyin", opName);
      if (parseOperandAndTypeList(parser, copyins, copyinTypes))
        return failure();
      segments[copyinClausePos] = copyins.size();
    } else if (keyword == "allocate") {
      // Fail if there was already another allocate clause.
      if (segments[allocateClausePos])
        return allowedOnce(parser, "allocate", opName);
      if (parseAllocateAndAllocator(parser, allocates, allocateTypes,
                                    allocators, allocatorTypes))
        return failure();
      // One clause fills two segments: the allocated variables and their
      // allocators.
      segments[allocateClausePos] = allocates.size();
      segments[allocatorPos] = allocators.size();
    } else if (keyword == "default") {
      // Fail if there was already another default clause.
      if (defaultVal)
        return allowedOnce(parser, "default", opName);
      defaultVal = true;
      StringRef defval;
      if (parser.parseLParen() || parser.parseKeyword(&defval) ||
          parser.parseRParen())
        return failure();
      // The def prefix is required for the attribute as "private" is a keyword
      // in C++.
      auto attr = parser.getBuilder().getStringAttr("def" + defval);
      result.addAttribute("default_val", attr);
    } else if (keyword == "proc_bind") {
      // Fail if there was already another proc_bind clause.
      if (procBind)
        return allowedOnce(parser, "proc_bind", opName);
      procBind = true;
      StringRef bind;
      if (parser.parseLParen() || parser.parseKeyword(&bind) ||
          parser.parseRParen())
        return failure();
      auto attr = parser.getBuilder().getStringAttr(bind);
      result.addAttribute("proc_bind_val", attr);
    } else {
      return parser.emitError(parser.getNameLoc())
             << keyword << " is not a valid clause for the " << opName
             << " operation";
    }
  }

  // Resolve the collected operands in segment order so that result.operands
  // lines up with operand_segment_sizes.

  // Add if parameter.
  if (segments[ifClausePos] &&
      parser.resolveOperand(ifCond.first, ifCond.second, result.operands))
    return failure();

  // Add num_threads parameter.
  if (segments[numThreadsClausePos] &&
      parser.resolveOperand(numThreads.first, numThreads.second,
                            result.operands))
    return failure();

  // Add private parameters.
  if (segments[privateClausePos] &&
      parser.resolveOperands(privates, privateTypes, privates[0].location,
                             result.operands))
    return failure();

  // Add firstprivate parameters.
  if (segments[firstprivateClausePos] &&
      parser.resolveOperands(firstprivates, firstprivateTypes,
                             firstprivates[0].location, result.operands))
    return failure();

  // Add shared parameters.
  if (segments[sharedClausePos] &&
      parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
                             result.operands))
    return failure();

  // Add copyin parameters.
  if (segments[copyinClausePos] &&
      parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
                             result.operands))
    return failure();

  // Add allocate parameters.
  if (segments[allocateClausePos] &&
      parser.resolveOperands(allocates, allocateTypes, allocates[0].location,
                             result.operands))
    return failure();

  // Add allocator parameters.
  if (segments[allocatorPos] &&
      parser.resolveOperands(allocators, allocatorTypes, allocators[0].location,
                             result.operands))
    return failure();

  result.addAttribute("operand_segment_sizes",
                      parser.getBuilder().getI32VectorAttr(segments));

  // The body region takes no entry-block arguments.
  Region *body = result.addRegion();
  SmallVector<OpAsmParser::OperandType, 4> regionArgs;
  SmallVector<Type, 4> regionArgTypes;
  if (parser.parseRegion(*body, regionArgs, regionArgTypes))
    return failure();
  return success();
}
389 
390 /// linear ::= `linear` `(` linear-list `)`
391 /// linear-list := linear-val | linear-val linear-list
392 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
393 static ParseResult
394 parseLinearClause(OpAsmParser &parser,
395                   SmallVectorImpl<OpAsmParser::OperandType> &vars,
396                   SmallVectorImpl<Type> &types,
397                   SmallVectorImpl<OpAsmParser::OperandType> &stepVars) {
398   if (parser.parseLParen())
399     return failure();
400 
401   do {
402     OpAsmParser::OperandType var;
403     Type type;
404     OpAsmParser::OperandType stepVar;
405     if (parser.parseOperand(var) || parser.parseEqual() ||
406         parser.parseOperand(stepVar) || parser.parseColonType(type))
407       return failure();
408 
409     vars.push_back(var);
410     types.push_back(type);
411     stepVars.push_back(stepVar);
412   } while (succeeded(parser.parseOptionalComma()));
413 
414   if (parser.parseRParen())
415     return failure();
416 
417   return success();
418 }
419 
420 /// schedule ::= `schedule` `(` sched-list `)`
421 /// sched-list ::= sched-val | sched-val sched-list
422 /// sched-val ::= sched-with-chunk | sched-wo-chunk
423 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
424 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
425 /// sched-wo-chunk ::=  `auto` | `runtime`
426 static ParseResult
427 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
428                     Optional<OpAsmParser::OperandType> &chunkSize) {
429   if (parser.parseLParen())
430     return failure();
431 
432   StringRef keyword;
433   if (parser.parseKeyword(&keyword))
434     return failure();
435 
436   schedule = keyword;
437   if (keyword == "static" || keyword == "dynamic" || keyword == "guided") {
438     if (succeeded(parser.parseOptionalEqual())) {
439       chunkSize = OpAsmParser::OperandType{};
440       if (parser.parseOperand(*chunkSize))
441         return failure();
442     } else {
443       chunkSize = llvm::NoneType::None;
444     }
445   } else if (keyword == "auto" || keyword == "runtime") {
446     chunkSize = llvm::NoneType::None;
447   } else {
448     return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
449   }
450 
451   if (parser.parseRParen())
452     return failure();
453 
454   return success();
455 }
456 
457 /// reduction-init ::= `reduction` `(` reduction-entry-list `)`
458 /// reduction-entry-list ::= reduction-entry
459 ///                        | reduction-entry-list `,` reduction-entry
460 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
461 static ParseResult
462 parseReductionVarList(OpAsmParser &parser,
463                       SmallVectorImpl<SymbolRefAttr> &symbols,
464                       SmallVectorImpl<OpAsmParser::OperandType> &operands,
465                       SmallVectorImpl<Type> &types) {
466   if (failed(parser.parseLParen()))
467     return failure();
468 
469   do {
470     if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() ||
471         parser.parseOperand(operands.emplace_back()) ||
472         parser.parseColonType(types.emplace_back()))
473       return failure();
474   } while (succeeded(parser.parseOptionalComma()));
475   return parser.parseRParen();
476 }
477 
478 /// Parses an OpenMP Workshare Loop operation
479 ///
480 /// operation ::= `omp.wsloop` loop-control clause-list
481 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
482 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps
483 /// steps := `step` `(`ssa-id-list`)`
484 /// clause-list ::= clause | empty | clause-list
485 /// clause ::= private | firstprivate | lastprivate | linear | schedule |
486 //             collapse | nowait | ordered | order | inclusive
487 /// private ::= `private` `(` ssa-id-and-type-list `)`
488 /// firstprivate ::= `firstprivate` `(` ssa-id-and-type-list `)`
489 /// lastprivate ::= `lastprivate` `(` ssa-id-and-type-list `)`
490 /// linear ::= `linear` `(` linear-list `)`
491 /// schedule ::= `schedule` `(` sched-list `)`
492 /// collapse ::= `collapse` `(` ssa-id-and-type `)`
493 /// nowait ::= `nowait`
494 /// ordered ::= `ordered` `(` ssa-id-and-type `)`
495 /// order ::= `order` `(` `concurrent` `)`
496 /// inclusive ::= `inclusive`
497 ///
498 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
499   Type loopVarType;
500   int numIVs;
501 
502   // Parse an opening `(` followed by induction variables followed by `)`
503   SmallVector<OpAsmParser::OperandType> ivs;
504   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
505                                      OpAsmParser::Delimiter::Paren))
506     return failure();
507 
508   numIVs = static_cast<int>(ivs.size());
509 
510   if (parser.parseColonType(loopVarType))
511     return failure();
512 
513   // Parse loop bounds.
514   SmallVector<OpAsmParser::OperandType> lower;
515   if (parser.parseEqual() ||
516       parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
517       parser.resolveOperands(lower, loopVarType, result.operands))
518     return failure();
519 
520   SmallVector<OpAsmParser::OperandType> upper;
521   if (parser.parseKeyword("to") ||
522       parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
523       parser.resolveOperands(upper, loopVarType, result.operands))
524     return failure();
525 
526   // Parse step values.
527   SmallVector<OpAsmParser::OperandType> steps;
528   if (parser.parseKeyword("step") ||
529       parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
530       parser.resolveOperands(steps, loopVarType, result.operands))
531     return failure();
532 
533   SmallVector<OpAsmParser::OperandType> privates;
534   SmallVector<Type> privateTypes;
535   SmallVector<OpAsmParser::OperandType> firstprivates;
536   SmallVector<Type> firstprivateTypes;
537   SmallVector<OpAsmParser::OperandType> lastprivates;
538   SmallVector<Type> lastprivateTypes;
539   SmallVector<OpAsmParser::OperandType> linears;
540   SmallVector<Type> linearTypes;
541   SmallVector<OpAsmParser::OperandType> linearSteps;
542   SmallVector<SymbolRefAttr> reductionSymbols;
543   SmallVector<OpAsmParser::OperandType> reductionVars;
544   SmallVector<Type> reductionVarTypes;
545   SmallString<8> schedule;
546   Optional<OpAsmParser::OperandType> scheduleChunkSize;
547 
548   const StringRef opName = result.name.getStringRef();
549   StringRef keyword;
550 
551   enum SegmentPos {
552     lbPos = 0,
553     ubPos,
554     stepPos,
555     privateClausePos,
556     firstprivateClausePos,
557     lastprivateClausePos,
558     linearClausePos,
559     linearStepPos,
560     reductionVarPos,
561     scheduleClausePos,
562   };
563   std::array<int, 10> segments{numIVs, numIVs, numIVs, 0, 0, 0, 0, 0, 0, 0};
564 
565   while (succeeded(parser.parseOptionalKeyword(&keyword))) {
566     if (keyword == "private") {
567       if (segments[privateClausePos])
568         return allowedOnce(parser, "private", opName);
569       if (parseOperandAndTypeList(parser, privates, privateTypes))
570         return failure();
571       segments[privateClausePos] = privates.size();
572     } else if (keyword == "firstprivate") {
573       // fail if there was already another firstprivate clause
574       if (segments[firstprivateClausePos])
575         return allowedOnce(parser, "firstprivate", opName);
576       if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
577         return failure();
578       segments[firstprivateClausePos] = firstprivates.size();
579     } else if (keyword == "lastprivate") {
580       // fail if there was already another shared clause
581       if (segments[lastprivateClausePos])
582         return allowedOnce(parser, "lastprivate", opName);
583       if (parseOperandAndTypeList(parser, lastprivates, lastprivateTypes))
584         return failure();
585       segments[lastprivateClausePos] = lastprivates.size();
586     } else if (keyword == "linear") {
587       // fail if there was already another linear clause
588       if (segments[linearClausePos])
589         return allowedOnce(parser, "linear", opName);
590       if (parseLinearClause(parser, linears, linearTypes, linearSteps))
591         return failure();
592       segments[linearClausePos] = linears.size();
593       segments[linearStepPos] = linearSteps.size();
594     } else if (keyword == "schedule") {
595       if (!schedule.empty())
596         return allowedOnce(parser, "schedule", opName);
597       if (parseScheduleClause(parser, schedule, scheduleChunkSize))
598         return failure();
599       if (scheduleChunkSize) {
600         segments[scheduleClausePos] = 1;
601       }
602     } else if (keyword == "collapse") {
603       auto type = parser.getBuilder().getI64Type();
604       mlir::IntegerAttr attr;
605       if (parser.parseLParen() || parser.parseAttribute(attr, type) ||
606           parser.parseRParen())
607         return failure();
608       result.addAttribute("collapse_val", attr);
609     } else if (keyword == "nowait") {
610       auto attr = UnitAttr::get(parser.getBuilder().getContext());
611       result.addAttribute("nowait", attr);
612     } else if (keyword == "ordered") {
613       mlir::IntegerAttr attr;
614       if (succeeded(parser.parseOptionalLParen())) {
615         auto type = parser.getBuilder().getI64Type();
616         if (parser.parseAttribute(attr, type))
617           return failure();
618         if (parser.parseRParen())
619           return failure();
620       } else {
621         // Use 0 to represent no ordered parameter was specified
622         attr = parser.getBuilder().getI64IntegerAttr(0);
623       }
624       result.addAttribute("ordered_val", attr);
625     } else if (keyword == "order") {
626       StringRef order;
627       if (parser.parseLParen() || parser.parseKeyword(&order) ||
628           parser.parseRParen())
629         return failure();
630       auto attr = parser.getBuilder().getStringAttr(order);
631       result.addAttribute("order", attr);
632     } else if (keyword == "inclusive") {
633       auto attr = UnitAttr::get(parser.getBuilder().getContext());
634       result.addAttribute("inclusive", attr);
635     } else if (keyword == "reduction") {
636       if (segments[reductionVarPos])
637         return allowedOnce(parser, "reduction", opName);
638       if (failed(parseReductionVarList(parser, reductionSymbols, reductionVars,
639                                        reductionVarTypes)))
640         return failure();
641       segments[reductionVarPos] = reductionVars.size();
642     }
643   }
644 
645   if (segments[privateClausePos]) {
646     parser.resolveOperands(privates, privateTypes, privates[0].location,
647                            result.operands);
648   }
649 
650   if (segments[firstprivateClausePos]) {
651     parser.resolveOperands(firstprivates, firstprivateTypes,
652                            firstprivates[0].location, result.operands);
653   }
654 
655   if (segments[lastprivateClausePos]) {
656     parser.resolveOperands(lastprivates, lastprivateTypes,
657                            lastprivates[0].location, result.operands);
658   }
659 
660   if (segments[linearClausePos]) {
661     parser.resolveOperands(linears, linearTypes, linears[0].location,
662                            result.operands);
663     auto linearStepType = parser.getBuilder().getI32Type();
664     SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
665     parser.resolveOperands(linearSteps, linearStepTypes,
666                            linearSteps[0].location, result.operands);
667   }
668 
669   if (segments[reductionVarPos]) {
670     if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
671                                       parser.getNameLoc(), result.operands))) {
672       return failure();
673     }
674     SmallVector<Attribute> reductions(reductionSymbols.begin(),
675                                       reductionSymbols.end());
676     result.addAttribute("reductions",
677                         parser.getBuilder().getArrayAttr(reductions));
678   }
679 
680   if (!schedule.empty()) {
681     schedule[0] = llvm::toUpper(schedule[0]);
682     auto attr = parser.getBuilder().getStringAttr(schedule);
683     result.addAttribute("schedule_val", attr);
684     if (scheduleChunkSize) {
685       auto chunkSizeType = parser.getBuilder().getI32Type();
686       parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands);
687     }
688   }
689 
690   result.addAttribute("operand_segment_sizes",
691                       parser.getBuilder().getI32VectorAttr(segments));
692 
693   // Now parse the body.
694   Region *body = result.addRegion();
695   SmallVector<Type> ivTypes(numIVs, loopVarType);
696   SmallVector<OpAsmParser::OperandType> blockArgs(ivs);
697   if (parser.parseRegion(*body, blockArgs, ivTypes))
698     return failure();
699   return success();
700 }
701 
702 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
703   auto args = op.getRegion().front().getArguments();
704   p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound()
705     << ") to (" << op.upperBound() << ") step (" << op.step() << ")";
706 
707   // Print private, firstprivate, shared and copyin parameters
708   auto printDataVars = [&p](StringRef name, OperandRange vars) {
709     if (vars.empty())
710       return;
711 
712     p << " " << name << "(";
713     llvm::interleaveComma(
714         vars, p, [&](const Value &v) { p << v << " : " << v.getType(); });
715     p << ")";
716   };
717   printDataVars("private", op.private_vars());
718   printDataVars("firstprivate", op.firstprivate_vars());
719   printDataVars("lastprivate", op.lastprivate_vars());
720 
721   auto linearVars = op.linear_vars();
722   auto linearVarsSize = linearVars.size();
723   if (linearVarsSize) {
724     p << " "
725       << "linear"
726       << "(";
727     for (unsigned i = 0; i < linearVarsSize; ++i) {
728       std::string separator = i == linearVarsSize - 1 ? ")" : ", ";
729       p << linearVars[i];
730       if (op.linear_step_vars().size() > i)
731         p << " = " << op.linear_step_vars()[i];
732       p << " : " << linearVars[i].getType() << separator;
733     }
734   }
735 
736   if (auto sched = op.schedule_val()) {
737     auto schedLower = sched->lower();
738     p << " schedule(" << schedLower;
739     if (auto chunk = op.schedule_chunk_var()) {
740       p << " = " << chunk;
741     }
742     p << ")";
743   }
744 
745   if (auto collapse = op.collapse_val())
746     p << " collapse(" << collapse << ")";
747 
748   if (op.nowait())
749     p << " nowait";
750 
751   if (auto ordered = op.ordered_val()) {
752     p << " ordered(" << ordered << ")";
753   }
754 
755   if (!op.reduction_vars().empty()) {
756     p << " reduction(";
757     for (unsigned i = 0, e = op.getNumReductionVars(); i < e; ++i) {
758       if (i != 0)
759         p << ", ";
760       p << (*op.reductions())[i] << " -> " << op.reduction_vars()[i] << " : "
761         << op.reduction_vars()[i].getType();
762     }
763     p << ")";
764   }
765 
766   if (op.inclusive()) {
767     p << " inclusive";
768   }
769 
770   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
771 }
772 
773 //===----------------------------------------------------------------------===//
774 // ReductionOp
775 //===----------------------------------------------------------------------===//
776 
777 static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
778                                               Region &region) {
779   if (parser.parseOptionalKeyword("atomic"))
780     return success();
781   return parser.parseRegion(region);
782 }
783 
784 static void printAtomicReductionRegion(OpAsmPrinter &printer,
785                                        ReductionDeclareOp op, Region &region) {
786   if (region.empty())
787     return;
788   printer << "atomic ";
789   printer.printRegion(region);
790 }
791 
792 static LogicalResult verifyReductionDeclareOp(ReductionDeclareOp op) {
793   if (op.initializerRegion().empty())
794     return op.emitOpError() << "expects non-empty initializer region";
795   Block &initializerEntryBlock = op.initializerRegion().front();
796   if (initializerEntryBlock.getNumArguments() != 1 ||
797       initializerEntryBlock.getArgument(0).getType() != op.type()) {
798     return op.emitOpError() << "expects initializer region with one argument "
799                                "of the reduction type";
800   }
801 
802   for (YieldOp yieldOp : op.initializerRegion().getOps<YieldOp>()) {
803     if (yieldOp.results().size() != 1 ||
804         yieldOp.results().getTypes()[0] != op.type())
805       return op.emitOpError() << "expects initializer region to yield a value "
806                                  "of the reduction type";
807   }
808 
809   if (op.reductionRegion().empty())
810     return op.emitOpError() << "expects non-empty reduction region";
811   Block &reductionEntryBlock = op.reductionRegion().front();
812   if (reductionEntryBlock.getNumArguments() != 2 ||
813       reductionEntryBlock.getArgumentTypes()[0] !=
814           reductionEntryBlock.getArgumentTypes()[1] ||
815       reductionEntryBlock.getArgumentTypes()[0] != op.type())
816     return op.emitOpError() << "expects reduction region with two arguments of "
817                                "the reduction type";
818   for (YieldOp yieldOp : op.reductionRegion().getOps<YieldOp>()) {
819     if (yieldOp.results().size() != 1 ||
820         yieldOp.results().getTypes()[0] != op.type())
821       return op.emitOpError() << "expects reduction region to yield a value "
822                                  "of the reduction type";
823   }
824 
825   if (op.atomicReductionRegion().empty())
826     return success();
827 
828   Block &atomicReductionEntryBlock = op.atomicReductionRegion().front();
829   if (atomicReductionEntryBlock.getNumArguments() != 2 ||
830       atomicReductionEntryBlock.getArgumentTypes()[0] !=
831           atomicReductionEntryBlock.getArgumentTypes()[1])
832     return op.emitOpError() << "expects atomic reduction region with two "
833                                "arguments of the same type";
834   auto ptrType = atomicReductionEntryBlock.getArgumentTypes()[0]
835                      .dyn_cast<PointerLikeType>();
836   if (!ptrType || ptrType.getElementType() != op.type())
837     return op.emitOpError() << "expects atomic reduction region arguments to "
838                                "be accumulators containing the reduction type";
839   return success();
840 }
841 
842 static LogicalResult verifyReductionOp(ReductionOp op) {
843   // TODO: generalize this to an op interface when there is more than one op
844   // that supports reductions.
845   auto container = op->getParentOfType<WsLoopOp>();
846   for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i)
847     if (container.reduction_vars()[i] == op.accumulator())
848       return success();
849 
850   return op.emitOpError() << "the accumulator is not used by the parent";
851 }
852 
853 //===----------------------------------------------------------------------===//
854 // WsLoopOp
855 //===----------------------------------------------------------------------===//
856 
857 void WsLoopOp::build(OpBuilder &builder, OperationState &state,
858                      ValueRange lowerBound, ValueRange upperBound,
859                      ValueRange step, ArrayRef<NamedAttribute> attributes) {
860   build(builder, state, TypeRange(), lowerBound, upperBound, step,
861         /*private_vars=*/ValueRange(),
862         /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
863         /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
864         /*reduction_vars=*/ValueRange(), /*schedule_val=*/nullptr,
865         /*schedule_chunk_var=*/nullptr, /*collapse_val=*/nullptr,
866         /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr,
867         /*inclusive=*/nullptr, /*buildBody=*/false);
868   state.addAttributes(attributes);
869 }
870 
871 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes,
872                      ValueRange operands, ArrayRef<NamedAttribute> attributes) {
873   state.addOperands(operands);
874   state.addAttributes(attributes);
875   (void)state.addRegion();
876   assert(resultTypes.empty() && "mismatched number of return types");
877   state.addTypes(resultTypes);
878 }
879 
880 void WsLoopOp::build(OpBuilder &builder, OperationState &result,
881                      TypeRange typeRange, ValueRange lowerBounds,
882                      ValueRange upperBounds, ValueRange steps,
883                      ValueRange privateVars, ValueRange firstprivateVars,
884                      ValueRange lastprivateVars, ValueRange linearVars,
885                      ValueRange linearStepVars, ValueRange reductionVars,
886                      StringAttr scheduleVal, Value scheduleChunkVar,
887                      IntegerAttr collapseVal, UnitAttr nowait,
888                      IntegerAttr orderedVal, StringAttr orderVal,
889                      UnitAttr inclusive, bool buildBody) {
890   result.addOperands(lowerBounds);
891   result.addOperands(upperBounds);
892   result.addOperands(steps);
893   result.addOperands(privateVars);
894   result.addOperands(firstprivateVars);
895   result.addOperands(linearVars);
896   result.addOperands(linearStepVars);
897   if (scheduleChunkVar)
898     result.addOperands(scheduleChunkVar);
899 
900   if (scheduleVal)
901     result.addAttribute("schedule_val", scheduleVal);
902   if (collapseVal)
903     result.addAttribute("collapse_val", collapseVal);
904   if (nowait)
905     result.addAttribute("nowait", nowait);
906   if (orderedVal)
907     result.addAttribute("ordered_val", orderedVal);
908   if (orderVal)
909     result.addAttribute("order", orderVal);
910   if (inclusive)
911     result.addAttribute("inclusive", inclusive);
912   result.addAttribute(
913       WsLoopOp::getOperandSegmentSizeAttr(),
914       builder.getI32VectorAttr(
915           {static_cast<int32_t>(lowerBounds.size()),
916            static_cast<int32_t>(upperBounds.size()),
917            static_cast<int32_t>(steps.size()),
918            static_cast<int32_t>(privateVars.size()),
919            static_cast<int32_t>(firstprivateVars.size()),
920            static_cast<int32_t>(lastprivateVars.size()),
921            static_cast<int32_t>(linearVars.size()),
922            static_cast<int32_t>(linearStepVars.size()),
923            static_cast<int32_t>(reductionVars.size()),
924            static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)}));
925 
926   Region *bodyRegion = result.addRegion();
927   if (buildBody) {
928     OpBuilder::InsertionGuard guard(builder);
929     unsigned numIVs = steps.size();
930     SmallVector<Type, 8> argTypes(numIVs, steps.getType().front());
931     builder.createBlock(bodyRegion, {}, argTypes);
932   }
933 }
934 
935 static LogicalResult verifyWsLoopOp(WsLoopOp op) {
936   if (op.getNumReductionVars() != 0) {
937     if (!op.reductions() ||
938         op.reductions()->size() != op.getNumReductionVars()) {
939       return op.emitOpError() << "expected as many reduction symbol references "
940                                  "as reduction variables";
941     }
942   } else {
943     if (op.reductions())
944       return op.emitOpError() << "unexpected reduction symbol references";
945     return success();
946   }
947 
948   DenseSet<Value> accumulators;
949   for (auto args : llvm::zip(op.reduction_vars(), *op.reductions())) {
950     Value accum = std::get<0>(args);
951     if (!accumulators.insert(accum).second) {
952       return op.emitOpError() << "accumulator variable used more than once";
953     }
954     Type varType = accum.getType().cast<PointerLikeType>();
955     auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
956     auto decl =
957         SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
958     if (!decl) {
959       return op.emitOpError() << "expected symbol reference " << symbolRef
960                               << " to point to a reduction declaration";
961     }
962 
963     if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) {
964       return op.emitOpError()
965              << "expected accumulator (" << varType
966              << ") to be the same type as reduction declaration ("
967              << decl.getAccumulatorType() << ")";
968     }
969   }
970 
971   return success();
972 }
973 
974 static LogicalResult verifyCriticalOp(CriticalOp op) {
975   if (!op.name().hasValue() && op.hint().hasValue() &&
976       (op.hint().getValue() != SyncHintKind::none))
977     return op.emitOpError() << "must specify a name unless the effect is as if "
978                                "hint(none) is specified";
979 
980   if (op.nameAttr()) {
981     auto symbolRef = op.nameAttr().cast<SymbolRefAttr>();
982     auto decl =
983         SymbolTable::lookupNearestSymbolFrom<CriticalDeclareOp>(op, symbolRef);
984     if (!decl) {
985       return op.emitOpError() << "expected symbol reference " << symbolRef
986                               << " to point to a critical declaration";
987     }
988   }
989 
990   return success();
991 }
992 
993 #define GET_OP_CLASSES
994 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
995