1 //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the OpenMP dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/OpImplementation.h"
17 #include "mlir/IR/OperationSupport.h"
18 
19 #include "llvm/ADT/SmallString.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/ADT/StringRef.h"
22 #include "llvm/ADT/StringSwitch.h"
23 #include <cstddef>
24 
25 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
26 
27 using namespace mlir;
28 using namespace mlir::omp;
29 
/// Registers the operations of the OpenMP dialect by instantiating
/// `addOperations` with the list taken from the included `OpenMPOps.cpp.inc`.
void OpenMPDialect::initialize() {
  addOperations<
// GET_OP_LIST is defined right before the include so that the included file
// contributes the template argument list here (presumably the op-class list
// section of the generated file -- confirm against the generated output).
#define GET_OP_LIST
#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
      >();
}
36 
37 //===----------------------------------------------------------------------===//
38 // ParallelOp
39 //===----------------------------------------------------------------------===//
40 
41 void ParallelOp::build(OpBuilder &builder, OperationState &state,
42                        ArrayRef<NamedAttribute> attributes) {
43   ParallelOp::build(
44       builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
45       /*default_val=*/nullptr, /*private_vars=*/ValueRange(),
46       /*firstprivate_vars=*/ValueRange(), /*shared_vars=*/ValueRange(),
47       /*copyin_vars=*/ValueRange(), /*allocate_vars=*/ValueRange(),
48       /*allocators_vars=*/ValueRange(), /*proc_bind_val=*/nullptr);
49   state.addAttributes(attributes);
50 }
51 
52 /// Parse a list of operands with types.
53 ///
54 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
55 /// ssa-id-and-type-list ::= ssa-id-and-type |
56 ///                          ssa-id-and-type `,` ssa-id-and-type-list
57 /// ssa-id-and-type ::= ssa-id `:` type
58 static ParseResult
59 parseOperandAndTypeList(OpAsmParser &parser,
60                         SmallVectorImpl<OpAsmParser::OperandType> &operands,
61                         SmallVectorImpl<Type> &types) {
62   if (parser.parseLParen())
63     return failure();
64 
65   do {
66     OpAsmParser::OperandType operand;
67     Type type;
68     if (parser.parseOperand(operand) || parser.parseColonType(type))
69       return failure();
70     operands.push_back(operand);
71     types.push_back(type);
72   } while (succeeded(parser.parseOptionalComma()));
73 
74   if (parser.parseRParen())
75     return failure();
76 
77   return success();
78 }
79 
80 /// Parse an allocate clause with allocators and a list of operands with types.
81 ///
82 /// operand-and-type-list ::= `(` allocate-operand-list `)`
83 /// allocate-operand-list :: = allocate-operand |
84 ///                            allocator-operand `,` allocate-operand-list
85 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
86 /// ssa-id-and-type ::= ssa-id `:` type
87 static ParseResult parseAllocateAndAllocator(
88     OpAsmParser &parser,
89     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocate,
90     SmallVectorImpl<Type> &typesAllocate,
91     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
92     SmallVectorImpl<Type> &typesAllocator) {
93   if (parser.parseLParen())
94     return failure();
95 
96   do {
97     OpAsmParser::OperandType operand;
98     Type type;
99 
100     if (parser.parseOperand(operand) || parser.parseColonType(type))
101       return failure();
102     operandsAllocator.push_back(operand);
103     typesAllocator.push_back(type);
104     if (parser.parseArrow())
105       return failure();
106     if (parser.parseOperand(operand) || parser.parseColonType(type))
107       return failure();
108 
109     operandsAllocate.push_back(operand);
110     typesAllocate.push_back(type);
111   } while (succeeded(parser.parseOptionalComma()));
112 
113   if (parser.parseRParen())
114     return failure();
115 
116   return success();
117 }
118 
119 static LogicalResult verifyParallelOp(ParallelOp op) {
120   if (op.allocate_vars().size() != op.allocators_vars().size())
121     return op.emitError(
122         "expected equal sizes for allocate and allocator variables");
123   return success();
124 }
125 
126 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
127   p << "omp.parallel";
128 
129   if (auto ifCond = op.if_expr_var())
130     p << " if(" << ifCond << " : " << ifCond.getType() << ")";
131 
132   if (auto threads = op.num_threads_var())
133     p << " num_threads(" << threads << " : " << threads.getType() << ")";
134 
135   // Print private, firstprivate, shared and copyin parameters
136   auto printDataVars = [&p](StringRef name, OperandRange vars) {
137     if (vars.size()) {
138       p << " " << name << "(";
139       for (unsigned i = 0; i < vars.size(); ++i) {
140         std::string separator = i == vars.size() - 1 ? ")" : ", ";
141         p << vars[i] << " : " << vars[i].getType() << separator;
142       }
143     }
144   };
145 
146   // Print allocator and allocate parameters
147   auto printAllocateAndAllocator = [&p](OperandRange varsAllocate,
148                                         OperandRange varsAllocator) {
149     if (varsAllocate.empty())
150       return;
151 
152     p << " allocate(";
153     for (unsigned i = 0; i < varsAllocate.size(); ++i) {
154       std::string separator = i == varsAllocate.size() - 1 ? ")" : ", ";
155       p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
156       p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
157     }
158   };
159 
160   printDataVars("private", op.private_vars());
161   printDataVars("firstprivate", op.firstprivate_vars());
162   printDataVars("shared", op.shared_vars());
163   printDataVars("copyin", op.copyin_vars());
164   printAllocateAndAllocator(op.allocate_vars(), op.allocators_vars());
165 
166   if (auto def = op.default_val())
167     p << " default(" << def->drop_front(3) << ")";
168 
169   if (auto bind = op.proc_bind_val())
170     p << " proc_bind(" << bind << ")";
171 
172   p.printRegion(op.getRegion());
173 }
174 
175 /// Emit an error if the same clause is present more than once on an operation.
176 static ParseResult allowedOnce(OpAsmParser &parser, StringRef clause,
177                                StringRef operation) {
178   return parser.emitError(parser.getNameLoc())
179          << " at most one " << clause << " clause can appear on the "
180          << operation << " operation";
181 }
182 
183 /// Parses a parallel operation.
184 ///
185 /// operation ::= `omp.parallel` clause-list
186 /// clause-list ::= clause | clause clause-list
187 /// clause ::= if | numThreads | private | firstprivate | shared | copyin |
188 ///            default | procBind
189 /// if ::= `if` `(` ssa-id `)`
190 /// numThreads ::= `num_threads` `(` ssa-id-and-type `)`
191 /// private ::= `private` operand-and-type-list
192 /// firstprivate ::= `firstprivate` operand-and-type-list
193 /// shared ::= `shared` operand-and-type-list
194 /// copyin ::= `copyin` operand-and-type-list
195 /// allocate ::= `allocate` operand-and-type `->` operand-and-type-list
196 /// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`)
197 /// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
198 ///
199 /// Note that each clause can only appear once in the clase-list.
200 static ParseResult parseParallelOp(OpAsmParser &parser,
201                                    OperationState &result) {
202   std::pair<OpAsmParser::OperandType, Type> ifCond;
203   std::pair<OpAsmParser::OperandType, Type> numThreads;
204   SmallVector<OpAsmParser::OperandType, 4> privates;
205   SmallVector<Type, 4> privateTypes;
206   SmallVector<OpAsmParser::OperandType, 4> firstprivates;
207   SmallVector<Type, 4> firstprivateTypes;
208   SmallVector<OpAsmParser::OperandType, 4> shareds;
209   SmallVector<Type, 4> sharedTypes;
210   SmallVector<OpAsmParser::OperandType, 4> copyins;
211   SmallVector<Type, 4> copyinTypes;
212   SmallVector<OpAsmParser::OperandType, 4> allocates;
213   SmallVector<Type, 4> allocateTypes;
214   SmallVector<OpAsmParser::OperandType, 4> allocators;
215   SmallVector<Type, 4> allocatorTypes;
216   std::array<int, 8> segments{0, 0, 0, 0, 0, 0, 0, 0};
217   StringRef keyword;
218   bool defaultVal = false;
219   bool procBind = false;
220 
221   const int ifClausePos = 0;
222   const int numThreadsClausePos = 1;
223   const int privateClausePos = 2;
224   const int firstprivateClausePos = 3;
225   const int sharedClausePos = 4;
226   const int copyinClausePos = 5;
227   const int allocateClausePos = 6;
228   const int allocatorPos = 7;
229   const StringRef opName = result.name.getStringRef();
230 
231   while (succeeded(parser.parseOptionalKeyword(&keyword))) {
232     if (keyword == "if") {
233       // Fail if there was already another if condition.
234       if (segments[ifClausePos])
235         return allowedOnce(parser, "if", opName);
236       if (parser.parseLParen() || parser.parseOperand(ifCond.first) ||
237           parser.parseColonType(ifCond.second) || parser.parseRParen())
238         return failure();
239       segments[ifClausePos] = 1;
240     } else if (keyword == "num_threads") {
241       // Fail if there was already another num_threads clause.
242       if (segments[numThreadsClausePos])
243         return allowedOnce(parser, "num_threads", opName);
244       if (parser.parseLParen() || parser.parseOperand(numThreads.first) ||
245           parser.parseColonType(numThreads.second) || parser.parseRParen())
246         return failure();
247       segments[numThreadsClausePos] = 1;
248     } else if (keyword == "private") {
249       // Fail if there was already another private clause.
250       if (segments[privateClausePos])
251         return allowedOnce(parser, "private", opName);
252       if (parseOperandAndTypeList(parser, privates, privateTypes))
253         return failure();
254       segments[privateClausePos] = privates.size();
255     } else if (keyword == "firstprivate") {
256       // Fail if there was already another firstprivate clause.
257       if (segments[firstprivateClausePos])
258         return allowedOnce(parser, "firstprivate", opName);
259       if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
260         return failure();
261       segments[firstprivateClausePos] = firstprivates.size();
262     } else if (keyword == "shared") {
263       // Fail if there was already another shared clause.
264       if (segments[sharedClausePos])
265         return allowedOnce(parser, "shared", opName);
266       if (parseOperandAndTypeList(parser, shareds, sharedTypes))
267         return failure();
268       segments[sharedClausePos] = shareds.size();
269     } else if (keyword == "copyin") {
270       // Fail if there was already another copyin clause.
271       if (segments[copyinClausePos])
272         return allowedOnce(parser, "copyin", opName);
273       if (parseOperandAndTypeList(parser, copyins, copyinTypes))
274         return failure();
275       segments[copyinClausePos] = copyins.size();
276     } else if (keyword == "allocate") {
277       // Fail if there was already another allocate clause.
278       if (segments[allocateClausePos])
279         return allowedOnce(parser, "allocate", opName);
280       if (parseAllocateAndAllocator(parser, allocates, allocateTypes,
281                                     allocators, allocatorTypes))
282         return failure();
283       segments[allocateClausePos] = allocates.size();
284       segments[allocatorPos] = allocators.size();
285     } else if (keyword == "default") {
286       // Fail if there was already another default clause.
287       if (defaultVal)
288         return allowedOnce(parser, "default", opName);
289       defaultVal = true;
290       StringRef defval;
291       if (parser.parseLParen() || parser.parseKeyword(&defval) ||
292           parser.parseRParen())
293         return failure();
294       // The def prefix is required for the attribute as "private" is a keyword
295       // in C++.
296       auto attr = parser.getBuilder().getStringAttr("def" + defval);
297       result.addAttribute("default_val", attr);
298     } else if (keyword == "proc_bind") {
299       // Fail if there was already another proc_bind clause.
300       if (procBind)
301         return allowedOnce(parser, "proc_bind", opName);
302       procBind = true;
303       StringRef bind;
304       if (parser.parseLParen() || parser.parseKeyword(&bind) ||
305           parser.parseRParen())
306         return failure();
307       auto attr = parser.getBuilder().getStringAttr(bind);
308       result.addAttribute("proc_bind_val", attr);
309     } else {
310       return parser.emitError(parser.getNameLoc())
311              << keyword << " is not a valid clause for the " << opName
312              << " operation";
313     }
314   }
315 
316   // Add if parameter.
317   if (segments[ifClausePos] &&
318       parser.resolveOperand(ifCond.first, ifCond.second, result.operands))
319     return failure();
320 
321   // Add num_threads parameter.
322   if (segments[numThreadsClausePos] &&
323       parser.resolveOperand(numThreads.first, numThreads.second,
324                             result.operands))
325     return failure();
326 
327   // Add private parameters.
328   if (segments[privateClausePos] &&
329       parser.resolveOperands(privates, privateTypes, privates[0].location,
330                              result.operands))
331     return failure();
332 
333   // Add firstprivate parameters.
334   if (segments[firstprivateClausePos] &&
335       parser.resolveOperands(firstprivates, firstprivateTypes,
336                              firstprivates[0].location, result.operands))
337     return failure();
338 
339   // Add shared parameters.
340   if (segments[sharedClausePos] &&
341       parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
342                              result.operands))
343     return failure();
344 
345   // Add copyin parameters.
346   if (segments[copyinClausePos] &&
347       parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
348                              result.operands))
349     return failure();
350 
351   // Add allocate parameters.
352   if (segments[allocateClausePos] &&
353       parser.resolveOperands(allocates, allocateTypes, allocates[0].location,
354                              result.operands))
355     return failure();
356 
357   // Add allocator parameters.
358   if (segments[allocatorPos] &&
359       parser.resolveOperands(allocators, allocatorTypes, allocators[0].location,
360                              result.operands))
361     return failure();
362 
363   result.addAttribute("operand_segment_sizes",
364                       parser.getBuilder().getI32VectorAttr(segments));
365 
366   Region *body = result.addRegion();
367   SmallVector<OpAsmParser::OperandType, 4> regionArgs;
368   SmallVector<Type, 4> regionArgTypes;
369   if (parser.parseRegion(*body, regionArgs, regionArgTypes))
370     return failure();
371   return success();
372 }
373 
374 /// linear ::= `linear` `(` linear-list `)`
375 /// linear-list := linear-val | linear-val linear-list
376 /// linear-val := ssa-id-and-type `=` ssa-id-and-type
377 static ParseResult
378 parseLinearClause(OpAsmParser &parser,
379                   SmallVectorImpl<OpAsmParser::OperandType> &vars,
380                   SmallVectorImpl<Type> &types,
381                   SmallVectorImpl<OpAsmParser::OperandType> &stepVars) {
382   if (parser.parseLParen())
383     return failure();
384 
385   do {
386     OpAsmParser::OperandType var;
387     Type type;
388     OpAsmParser::OperandType stepVar;
389     if (parser.parseOperand(var) || parser.parseEqual() ||
390         parser.parseOperand(stepVar) || parser.parseColonType(type))
391       return failure();
392 
393     vars.push_back(var);
394     types.push_back(type);
395     stepVars.push_back(stepVar);
396   } while (succeeded(parser.parseOptionalComma()));
397 
398   if (parser.parseRParen())
399     return failure();
400 
401   return success();
402 }
403 
404 /// schedule ::= `schedule` `(` sched-list `)`
405 /// sched-list ::= sched-val | sched-val sched-list
406 /// sched-val ::= sched-with-chunk | sched-wo-chunk
407 /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
408 /// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
409 /// sched-wo-chunk ::=  `auto` | `runtime`
410 static ParseResult
411 parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
412                     Optional<OpAsmParser::OperandType> &chunkSize) {
413   if (parser.parseLParen())
414     return failure();
415 
416   StringRef keyword;
417   if (parser.parseKeyword(&keyword))
418     return failure();
419 
420   schedule = keyword;
421   if (keyword == "static" || keyword == "dynamic" || keyword == "guided") {
422     if (succeeded(parser.parseOptionalEqual())) {
423       chunkSize = OpAsmParser::OperandType{};
424       if (parser.parseOperand(*chunkSize))
425         return failure();
426     } else {
427       chunkSize = llvm::NoneType::None;
428     }
429   } else if (keyword == "auto" || keyword == "runtime") {
430     chunkSize = llvm::NoneType::None;
431   } else {
432     return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
433   }
434 
435   if (parser.parseRParen())
436     return failure();
437 
438   return success();
439 }
440 
441 /// Parses an OpenMP Workshare Loop operation
442 ///
443 /// operation ::= `omp.wsloop` loop-control clause-list
444 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
445 /// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps
446 /// steps := `step` `(`ssa-id-list`)`
447 /// clause-list ::= clause | empty | clause-list
448 /// clause ::= private | firstprivate | lastprivate | linear | schedule |
449 //             collapse | nowait | ordered | order | inclusive
450 /// private ::= `private` `(` ssa-id-and-type-list `)`
451 /// firstprivate ::= `firstprivate` `(` ssa-id-and-type-list `)`
452 /// lastprivate ::= `lastprivate` `(` ssa-id-and-type-list `)`
453 /// linear ::= `linear` `(` linear-list `)`
454 /// schedule ::= `schedule` `(` sched-list `)`
455 /// collapse ::= `collapse` `(` ssa-id-and-type `)`
456 /// nowait ::= `nowait`
457 /// ordered ::= `ordered` `(` ssa-id-and-type `)`
458 /// order ::= `order` `(` `concurrent` `)`
459 /// inclusive ::= `inclusive`
460 ///
461 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
462   Type loopVarType;
463   int numIVs;
464 
465   // Parse an opening `(` followed by induction variables followed by `)`
466   SmallVector<OpAsmParser::OperandType> ivs;
467   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
468                                      OpAsmParser::Delimiter::Paren))
469     return failure();
470 
471   numIVs = static_cast<int>(ivs.size());
472 
473   if (parser.parseColonType(loopVarType))
474     return failure();
475 
476   // Parse loop bounds.
477   SmallVector<OpAsmParser::OperandType> lower;
478   if (parser.parseEqual() ||
479       parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
480       parser.resolveOperands(lower, loopVarType, result.operands))
481     return failure();
482 
483   SmallVector<OpAsmParser::OperandType> upper;
484   if (parser.parseKeyword("to") ||
485       parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
486       parser.resolveOperands(upper, loopVarType, result.operands))
487     return failure();
488 
489   // Parse step values.
490   SmallVector<OpAsmParser::OperandType> steps;
491   if (parser.parseKeyword("step") ||
492       parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
493       parser.resolveOperands(steps, loopVarType, result.operands))
494     return failure();
495 
496   SmallVector<OpAsmParser::OperandType> privates;
497   SmallVector<Type> privateTypes;
498   SmallVector<OpAsmParser::OperandType> firstprivates;
499   SmallVector<Type> firstprivateTypes;
500   SmallVector<OpAsmParser::OperandType> lastprivates;
501   SmallVector<Type> lastprivateTypes;
502   SmallVector<OpAsmParser::OperandType> linears;
503   SmallVector<Type> linearTypes;
504   SmallVector<OpAsmParser::OperandType> linearSteps;
505   SmallString<8> schedule;
506   Optional<OpAsmParser::OperandType> scheduleChunkSize;
507   std::array<int, 9> segments{numIVs, numIVs, numIVs, 0, 0, 0, 0, 0, 0};
508 
509   const StringRef opName = result.name.getStringRef();
510   StringRef keyword;
511 
512   enum SegmentPos {
513     lbPos = 0,
514     ubPos,
515     stepPos,
516     privateClausePos,
517     firstprivateClausePos,
518     lastprivateClausePos,
519     linearClausePos,
520     linearStepPos,
521     scheduleClausePos,
522   };
523 
524   while (succeeded(parser.parseOptionalKeyword(&keyword))) {
525     if (keyword == "private") {
526       if (segments[privateClausePos])
527         return allowedOnce(parser, "private", opName);
528       if (parseOperandAndTypeList(parser, privates, privateTypes))
529         return failure();
530       segments[privateClausePos] = privates.size();
531     } else if (keyword == "firstprivate") {
532       // fail if there was already another firstprivate clause
533       if (segments[firstprivateClausePos])
534         return allowedOnce(parser, "firstprivate", opName);
535       if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
536         return failure();
537       segments[firstprivateClausePos] = firstprivates.size();
538     } else if (keyword == "lastprivate") {
539       // fail if there was already another shared clause
540       if (segments[lastprivateClausePos])
541         return allowedOnce(parser, "lastprivate", opName);
542       if (parseOperandAndTypeList(parser, lastprivates, lastprivateTypes))
543         return failure();
544       segments[lastprivateClausePos] = lastprivates.size();
545     } else if (keyword == "linear") {
546       // fail if there was already another linear clause
547       if (segments[linearClausePos])
548         return allowedOnce(parser, "linear", opName);
549       if (parseLinearClause(parser, linears, linearTypes, linearSteps))
550         return failure();
551       segments[linearClausePos] = linears.size();
552       segments[linearStepPos] = linearSteps.size();
553     } else if (keyword == "schedule") {
554       if (!schedule.empty())
555         return allowedOnce(parser, "schedule", opName);
556       if (parseScheduleClause(parser, schedule, scheduleChunkSize))
557         return failure();
558       if (scheduleChunkSize) {
559         segments[scheduleClausePos] = 1;
560       }
561     } else if (keyword == "collapse") {
562       auto type = parser.getBuilder().getI64Type();
563       mlir::IntegerAttr attr;
564       if (parser.parseLParen() || parser.parseAttribute(attr, type) ||
565           parser.parseRParen())
566         return failure();
567       result.addAttribute("collapse_val", attr);
568     } else if (keyword == "nowait") {
569       auto attr = UnitAttr::get(parser.getBuilder().getContext());
570       result.addAttribute("nowait", attr);
571     } else if (keyword == "ordered") {
572       mlir::IntegerAttr attr;
573       if (succeeded(parser.parseOptionalLParen())) {
574         auto type = parser.getBuilder().getI64Type();
575         if (parser.parseAttribute(attr, type))
576           return failure();
577         if (parser.parseRParen())
578           return failure();
579       } else {
580         // Use 0 to represent no ordered parameter was specified
581         attr = parser.getBuilder().getI64IntegerAttr(0);
582       }
583       result.addAttribute("ordered_val", attr);
584     } else if (keyword == "order") {
585       StringRef order;
586       if (parser.parseLParen() || parser.parseKeyword(&order) ||
587           parser.parseRParen())
588         return failure();
589       auto attr = parser.getBuilder().getStringAttr(order);
590       result.addAttribute("order", attr);
591     } else if (keyword == "inclusive") {
592       auto attr = UnitAttr::get(parser.getBuilder().getContext());
593       result.addAttribute("inclusive", attr);
594     }
595   }
596 
597   if (segments[privateClausePos]) {
598     parser.resolveOperands(privates, privateTypes, privates[0].location,
599                            result.operands);
600   }
601 
602   if (segments[firstprivateClausePos]) {
603     parser.resolveOperands(firstprivates, firstprivateTypes,
604                            firstprivates[0].location, result.operands);
605   }
606 
607   if (segments[lastprivateClausePos]) {
608     parser.resolveOperands(lastprivates, lastprivateTypes,
609                            lastprivates[0].location, result.operands);
610   }
611 
612   if (segments[linearClausePos]) {
613     parser.resolveOperands(linears, linearTypes, linears[0].location,
614                            result.operands);
615     auto linearStepType = parser.getBuilder().getI32Type();
616     SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
617     parser.resolveOperands(linearSteps, linearStepTypes,
618                            linearSteps[0].location, result.operands);
619   }
620 
621   if (!schedule.empty()) {
622     schedule[0] = llvm::toUpper(schedule[0]);
623     auto attr = parser.getBuilder().getStringAttr(schedule);
624     result.addAttribute("schedule_val", attr);
625     if (scheduleChunkSize) {
626       auto chunkSizeType = parser.getBuilder().getI32Type();
627       parser.resolveOperand(*scheduleChunkSize, chunkSizeType, result.operands);
628     }
629   }
630 
631   result.addAttribute("operand_segment_sizes",
632                       parser.getBuilder().getI32VectorAttr(segments));
633 
634   // Now parse the body.
635   Region *body = result.addRegion();
636   SmallVector<Type> ivTypes(numIVs, loopVarType);
637   if (parser.parseRegion(*body, ivs, ivTypes))
638     return failure();
639   return success();
640 }
641 
642 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
643   auto args = op.getRegion().front().getArguments();
644   p << op.getOperationName() << " (" << args << ") : " << args[0].getType()
645     << " = (" << op.lowerBound() << ") to (" << op.upperBound() << ") step ("
646     << op.step() << ")";
647 
648   // Print private, firstprivate, shared and copyin parameters
649   auto printDataVars = [&p](StringRef name, OperandRange vars) {
650     if (vars.empty())
651       return;
652 
653     p << " " << name << "(";
654     llvm::interleaveComma(
655         vars, p, [&](const Value &v) { p << v << " : " << v.getType(); });
656     p << ")";
657   };
658   printDataVars("private", op.private_vars());
659   printDataVars("firstprivate", op.firstprivate_vars());
660   printDataVars("lastprivate", op.lastprivate_vars());
661 
662   auto linearVars = op.linear_vars();
663   auto linearVarsSize = linearVars.size();
664   if (linearVarsSize) {
665     p << " "
666       << "linear"
667       << "(";
668     for (unsigned i = 0; i < linearVarsSize; ++i) {
669       std::string separator = i == linearVarsSize - 1 ? ")" : ", ";
670       p << linearVars[i];
671       if (op.linear_step_vars().size() > i)
672         p << " = " << op.linear_step_vars()[i];
673       p << " : " << linearVars[i].getType() << separator;
674     }
675   }
676 
677   if (auto sched = op.schedule_val()) {
678     auto schedLower = sched->lower();
679     p << " schedule(" << schedLower;
680     if (auto chunk = op.schedule_chunk_var()) {
681       p << " = " << chunk;
682     }
683     p << ")";
684   }
685 
686   if (auto collapse = op.collapse_val())
687     p << " collapse(" << collapse << ")";
688 
689   if (op.nowait())
690     p << " nowait";
691 
692   if (auto ordered = op.ordered_val()) {
693     p << " ordered(" << ordered << ")";
694   }
695 
696   if (op.inclusive()) {
697     p << " inclusive";
698   }
699 
700   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
701 }
702 
703 //===----------------------------------------------------------------------===//
704 // WsLoopOp
705 //===----------------------------------------------------------------------===//
706 
707 void WsLoopOp::build(OpBuilder &builder, OperationState &state,
708                      ValueRange lowerBound, ValueRange upperBound,
709                      ValueRange step, ArrayRef<NamedAttribute> attributes) {
710   build(builder, state, TypeRange(), lowerBound, upperBound, step,
711         /*private_vars=*/ValueRange(),
712         /*firstprivate_vars=*/ValueRange(), /*lastprivate_vars=*/ValueRange(),
713         /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
714         /*schedule_val=*/nullptr, /*schedule_chunk_var=*/nullptr,
715         /*collapse_val=*/nullptr,
716         /*nowait=*/nullptr, /*ordered_val=*/nullptr, /*order_val=*/nullptr,
717         /*inclusive=*/nullptr, /*buildBody=*/false);
718   state.addAttributes(attributes);
719 }
720 
721 void WsLoopOp::build(OpBuilder &, OperationState &state, TypeRange resultTypes,
722                      ValueRange operands, ArrayRef<NamedAttribute> attributes) {
723   state.addOperands(operands);
724   state.addAttributes(attributes);
725   (void)state.addRegion();
726   assert(resultTypes.size() == 0u && "mismatched number of return types");
727   state.addTypes(resultTypes);
728 }
729 
730 void WsLoopOp::build(OpBuilder &builder, OperationState &result,
731                      TypeRange typeRange, ValueRange lowerBounds,
732                      ValueRange upperBounds, ValueRange steps,
733                      ValueRange privateVars, ValueRange firstprivateVars,
734                      ValueRange lastprivateVars, ValueRange linearVars,
735                      ValueRange linearStepVars, StringAttr scheduleVal,
736                      Value scheduleChunkVar, IntegerAttr collapseVal,
737                      UnitAttr nowait, IntegerAttr orderedVal,
738                      StringAttr orderVal, UnitAttr inclusive, bool buildBody) {
739   result.addOperands(lowerBounds);
740   result.addOperands(upperBounds);
741   result.addOperands(steps);
742   result.addOperands(privateVars);
743   result.addOperands(firstprivateVars);
744   result.addOperands(linearVars);
745   result.addOperands(linearStepVars);
746   if (scheduleChunkVar)
747     result.addOperands(scheduleChunkVar);
748 
749   if (scheduleVal)
750     result.addAttribute("schedule_val", scheduleVal);
751   if (collapseVal)
752     result.addAttribute("collapse_val", collapseVal);
753   if (nowait)
754     result.addAttribute("nowait", nowait);
755   if (orderedVal)
756     result.addAttribute("ordered_val", orderedVal);
757   if (orderVal)
758     result.addAttribute("order", orderVal);
759   if (inclusive)
760     result.addAttribute("inclusive", inclusive);
761   result.addAttribute(
762       WsLoopOp::getOperandSegmentSizeAttr(),
763       builder.getI32VectorAttr(
764           {static_cast<int32_t>(lowerBounds.size()),
765            static_cast<int32_t>(upperBounds.size()),
766            static_cast<int32_t>(steps.size()),
767            static_cast<int32_t>(privateVars.size()),
768            static_cast<int32_t>(firstprivateVars.size()),
769            static_cast<int32_t>(lastprivateVars.size()),
770            static_cast<int32_t>(linearVars.size()),
771            static_cast<int32_t>(linearStepVars.size()),
772            static_cast<int32_t>(scheduleChunkVar != nullptr ? 1 : 0)}));
773 
774   Region *bodyRegion = result.addRegion();
775   if (buildBody) {
776     OpBuilder::InsertionGuard guard(builder);
777     unsigned numIVs = steps.size();
778     SmallVector<Type, 8> argTypes(numIVs, steps.getType().front());
779     builder.createBlock(bodyRegion, {}, argTypes);
780   }
781 }
782 
783 #define GET_OP_CLASSES
784 #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
785