1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // =============================================================================
8 
9 #include "mlir/Dialect/OpenACC/OpenACC.h"
10 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "llvm/ADT/TypeSwitch.h"
18 
19 using namespace mlir;
20 using namespace acc;
21 
22 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
23 
24 //===----------------------------------------------------------------------===//
25 // OpenACC operations
26 //===----------------------------------------------------------------------===//
27 
/// Registers the operations and attributes of the OpenACC dialect. The lists
/// themselves come from the included `.cpp.inc` files.
void OpenACCDialect::initialize() {
  // GET_OP_LIST selects the operation list portion of the included file.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
      >();
  // GET_ATTRDEF_LIST selects the attribute list portion of the included file.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
      >();
}
38 
39 template <typename StructureOp>
40 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
41                                 unsigned nRegions = 1) {
42 
43   SmallVector<Region *, 2> regions;
44   for (unsigned i = 0; i < nRegions; ++i)
45     regions.push_back(state.addRegion());
46 
47   for (Region *region : regions) {
48     if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
49       return failure();
50   }
51 
52   return success();
53 }
54 
55 static ParseResult
56 parseOperandList(OpAsmParser &parser, StringRef keyword,
57                  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &args,
58                  SmallVectorImpl<Type> &argTypes, OperationState &result) {
59   if (failed(parser.parseOptionalKeyword(keyword)))
60     return success();
61 
62   if (failed(parser.parseLParen()))
63     return failure();
64 
65   // Exit early if the list is empty.
66   if (succeeded(parser.parseOptionalRParen()))
67     return success();
68 
69   do {
70     OpAsmParser::UnresolvedOperand arg;
71     Type type;
72 
73     if (parser.parseRegionArgument(arg) || parser.parseColonType(type))
74       return failure();
75 
76     args.push_back(arg);
77     argTypes.push_back(type);
78   } while (succeeded(parser.parseOptionalComma()));
79 
80   if (failed(parser.parseRParen()))
81     return failure();
82 
83   return parser.resolveOperands(args, argTypes, parser.getCurrentLocation(),
84                                 result.operands);
85 }
86 
87 static void printOperandList(Operation::operand_range operands,
88                              StringRef listName, OpAsmPrinter &printer) {
89 
90   if (!operands.empty()) {
91     printer << " " << listName << "(";
92     llvm::interleaveComma(operands, printer, [&](Value op) {
93       printer << op << ": " << op.getType();
94     });
95     printer << ")";
96   }
97 }
98 
99 static ParseResult parseOptionalOperand(OpAsmParser &parser, StringRef keyword,
100                                         OpAsmParser::UnresolvedOperand &operand,
101                                         Type type, bool &hasOptional,
102                                         OperationState &result) {
103   hasOptional = false;
104   if (succeeded(parser.parseOptionalKeyword(keyword))) {
105     hasOptional = true;
106     if (parser.parseLParen() || parser.parseOperand(operand) ||
107         parser.resolveOperand(operand, type, result.operands) ||
108         parser.parseRParen())
109       return failure();
110   }
111   return success();
112 }
113 
114 static ParseResult parseOperandAndType(OpAsmParser &parser,
115                                        OperationState &result) {
116   OpAsmParser::UnresolvedOperand operand;
117   Type type;
118   if (parser.parseOperand(operand) || parser.parseColonType(type) ||
119       parser.resolveOperand(operand, type, result.operands))
120     return failure();
121   return success();
122 }
123 
124 /// Parse optional operand and its type wrapped in parenthesis prefixed with
125 /// a keyword.
126 /// Example:
127 ///   keyword `(` %vectorLength: i64 `)`
128 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser,
129                                                        StringRef keyword,
130                                                        OperationState &result) {
131   OpAsmParser::UnresolvedOperand operand;
132   if (succeeded(parser.parseOptionalKeyword(keyword))) {
133     return failure(parser.parseLParen() ||
134                    parseOperandAndType(parser, result) || parser.parseRParen());
135   }
136   return llvm::None;
137 }
138 
139 /// Parse optional operand and its type wrapped in parenthesis.
140 /// Example:
141 ///   `(` %vectorLength: i64 `)`
142 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser,
143                                                        OperationState &result) {
144   if (succeeded(parser.parseOptionalLParen())) {
145     return failure(parseOperandAndType(parser, result) || parser.parseRParen());
146   }
147   return llvm::None;
148 }
149 
150 /// Parse optional operand with its type prefixed with prefixKeyword `=`.
151 /// Example:
152 ///   num=%gangNum: i32
153 static OptionalParseResult parserOptionalOperandAndTypeWithPrefix(
154     OpAsmParser &parser, OperationState &result, StringRef prefixKeyword) {
155   if (succeeded(parser.parseOptionalKeyword(prefixKeyword))) {
156     parser.parseEqual();
157     return parseOperandAndType(parser, result);
158   }
159   return llvm::None;
160 }
161 
162 static bool isComputeOperation(Operation *op) {
163   return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op);
164 }
165 
166 namespace {
167 /// Pattern to remove operation without region that have constant false `ifCond`
168 /// and remove the condition from the operation if the `ifCond` is a true
169 /// constant.
170 template <typename OpTy>
171 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
172   using OpRewritePattern<OpTy>::OpRewritePattern;
173 
174   LogicalResult matchAndRewrite(OpTy op,
175                                 PatternRewriter &rewriter) const override {
176     // Early return if there is no condition.
177     Value ifCond = op.ifCond();
178     if (!ifCond)
179       return success();
180 
181     IntegerAttr constAttr;
182     if (matchPattern(ifCond, m_Constant(&constAttr))) {
183       if (constAttr.getInt())
184         rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); });
185       else
186         rewriter.eraseOp(op);
187     }
188 
189     return success();
190   }
191 };
192 } // namespace
193 
194 //===----------------------------------------------------------------------===//
195 // ParallelOp
196 //===----------------------------------------------------------------------===//
197 
198 /// Parse acc.parallel operation
199 /// operation := `acc.parallel` `async` `(` index `)`?
200 ///                             `wait` `(` index-list `)`?
201 ///                             `num_gangs` `(` value `)`?
202 ///                             `num_workers` `(` value `)`?
203 ///                             `vector_length` `(` value `)`?
204 ///                             `if` `(` value `)`?
205 ///                             `self` `(` value `)`?
206 ///                             `reduction` `(` value-list `)`?
207 ///                             `copy` `(` value-list `)`?
208 ///                             `copyin` `(` value-list `)`?
209 ///                             `copyin_readonly` `(` value-list `)`?
210 ///                             `copyout` `(` value-list `)`?
211 ///                             `copyout_zero` `(` value-list `)`?
212 ///                             `create` `(` value-list `)`?
213 ///                             `create_zero` `(` value-list `)`?
214 ///                             `no_create` `(` value-list `)`?
215 ///                             `present` `(` value-list `)`?
216 ///                             `deviceptr` `(` value-list `)`?
217 ///                             `attach` `(` value-list `)`?
218 ///                             `private` `(` value-list `)`?
219 ///                             `firstprivate` `(` value-list `)`?
220 ///                             region attr-dict?
221 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
222   Builder &builder = parser.getBuilder();
223   SmallVector<OpAsmParser::UnresolvedOperand, 8> privateOperands,
224       firstprivateOperands, copyOperands, copyinOperands,
225       copyinReadonlyOperands, copyoutOperands, copyoutZeroOperands,
226       createOperands, createZeroOperands, noCreateOperands, presentOperands,
227       devicePtrOperands, attachOperands, waitOperands, reductionOperands;
228   SmallVector<Type, 8> waitOperandTypes, reductionOperandTypes,
229       copyOperandTypes, copyinOperandTypes, copyinReadonlyOperandTypes,
230       copyoutOperandTypes, copyoutZeroOperandTypes, createOperandTypes,
231       createZeroOperandTypes, noCreateOperandTypes, presentOperandTypes,
232       deviceptrOperandTypes, attachOperandTypes, privateOperandTypes,
233       firstprivateOperandTypes;
234 
235   SmallVector<Type, 8> operandTypes;
236   OpAsmParser::UnresolvedOperand ifCond, selfCond;
237   bool hasIfCond = false, hasSelfCond = false;
238   OptionalParseResult async, numGangs, numWorkers, vectorLength;
239   Type i1Type = builder.getI1Type();
240 
241   // async()?
242   async = parseOptionalOperandAndType(parser, ParallelOp::getAsyncKeyword(),
243                                       result);
244   if (async.hasValue() && failed(*async))
245     return failure();
246 
247   // wait()?
248   if (failed(parseOperandList(parser, ParallelOp::getWaitKeyword(),
249                               waitOperands, waitOperandTypes, result)))
250     return failure();
251 
252   // num_gangs(value)?
253   numGangs = parseOptionalOperandAndType(
254       parser, ParallelOp::getNumGangsKeyword(), result);
255   if (numGangs.hasValue() && failed(*numGangs))
256     return failure();
257 
258   // num_workers(value)?
259   numWorkers = parseOptionalOperandAndType(
260       parser, ParallelOp::getNumWorkersKeyword(), result);
261   if (numWorkers.hasValue() && failed(*numWorkers))
262     return failure();
263 
264   // vector_length(value)?
265   vectorLength = parseOptionalOperandAndType(
266       parser, ParallelOp::getVectorLengthKeyword(), result);
267   if (vectorLength.hasValue() && failed(*vectorLength))
268     return failure();
269 
270   // if()?
271   if (failed(parseOptionalOperand(parser, ParallelOp::getIfKeyword(), ifCond,
272                                   i1Type, hasIfCond, result)))
273     return failure();
274 
275   // self()?
276   if (failed(parseOptionalOperand(parser, ParallelOp::getSelfKeyword(),
277                                   selfCond, i1Type, hasSelfCond, result)))
278     return failure();
279 
280   // reduction()?
281   if (failed(parseOperandList(parser, ParallelOp::getReductionKeyword(),
282                               reductionOperands, reductionOperandTypes,
283                               result)))
284     return failure();
285 
286   // copy()?
287   if (failed(parseOperandList(parser, ParallelOp::getCopyKeyword(),
288                               copyOperands, copyOperandTypes, result)))
289     return failure();
290 
291   // copyin()?
292   if (failed(parseOperandList(parser, ParallelOp::getCopyinKeyword(),
293                               copyinOperands, copyinOperandTypes, result)))
294     return failure();
295 
296   // copyin_readonly()?
297   if (failed(parseOperandList(parser, ParallelOp::getCopyinReadonlyKeyword(),
298                               copyinReadonlyOperands,
299                               copyinReadonlyOperandTypes, result)))
300     return failure();
301 
302   // copyout()?
303   if (failed(parseOperandList(parser, ParallelOp::getCopyoutKeyword(),
304                               copyoutOperands, copyoutOperandTypes, result)))
305     return failure();
306 
307   // copyout_zero()?
308   if (failed(parseOperandList(parser, ParallelOp::getCopyoutZeroKeyword(),
309                               copyoutZeroOperands, copyoutZeroOperandTypes,
310                               result)))
311     return failure();
312 
313   // create()?
314   if (failed(parseOperandList(parser, ParallelOp::getCreateKeyword(),
315                               createOperands, createOperandTypes, result)))
316     return failure();
317 
318   // create_zero()?
319   if (failed(parseOperandList(parser, ParallelOp::getCreateZeroKeyword(),
320                               createZeroOperands, createZeroOperandTypes,
321                               result)))
322     return failure();
323 
324   // no_create()?
325   if (failed(parseOperandList(parser, ParallelOp::getNoCreateKeyword(),
326                               noCreateOperands, noCreateOperandTypes, result)))
327     return failure();
328 
329   // present()?
330   if (failed(parseOperandList(parser, ParallelOp::getPresentKeyword(),
331                               presentOperands, presentOperandTypes, result)))
332     return failure();
333 
334   // deviceptr()?
335   if (failed(parseOperandList(parser, ParallelOp::getDevicePtrKeyword(),
336                               devicePtrOperands, deviceptrOperandTypes,
337                               result)))
338     return failure();
339 
340   // attach()?
341   if (failed(parseOperandList(parser, ParallelOp::getAttachKeyword(),
342                               attachOperands, attachOperandTypes, result)))
343     return failure();
344 
345   // private()?
346   if (failed(parseOperandList(parser, ParallelOp::getPrivateKeyword(),
347                               privateOperands, privateOperandTypes, result)))
348     return failure();
349 
350   // firstprivate()?
351   if (failed(parseOperandList(parser, ParallelOp::getFirstPrivateKeyword(),
352                               firstprivateOperands, firstprivateOperandTypes,
353                               result)))
354     return failure();
355 
356   // Parallel op region
357   if (failed(parseRegions<ParallelOp>(parser, result)))
358     return failure();
359 
360   result.addAttribute(
361       ParallelOp::getOperandSegmentSizeAttr(),
362       builder.getI32VectorAttr(
363           {static_cast<int32_t>(async.hasValue() ? 1 : 0),
364            static_cast<int32_t>(waitOperands.size()),
365            static_cast<int32_t>(numGangs.hasValue() ? 1 : 0),
366            static_cast<int32_t>(numWorkers.hasValue() ? 1 : 0),
367            static_cast<int32_t>(vectorLength.hasValue() ? 1 : 0),
368            static_cast<int32_t>(hasIfCond ? 1 : 0),
369            static_cast<int32_t>(hasSelfCond ? 1 : 0),
370            static_cast<int32_t>(reductionOperands.size()),
371            static_cast<int32_t>(copyOperands.size()),
372            static_cast<int32_t>(copyinOperands.size()),
373            static_cast<int32_t>(copyinReadonlyOperands.size()),
374            static_cast<int32_t>(copyoutOperands.size()),
375            static_cast<int32_t>(copyoutZeroOperands.size()),
376            static_cast<int32_t>(createOperands.size()),
377            static_cast<int32_t>(createZeroOperands.size()),
378            static_cast<int32_t>(noCreateOperands.size()),
379            static_cast<int32_t>(presentOperands.size()),
380            static_cast<int32_t>(devicePtrOperands.size()),
381            static_cast<int32_t>(attachOperands.size()),
382            static_cast<int32_t>(privateOperands.size()),
383            static_cast<int32_t>(firstprivateOperands.size())}));
384 
385   // Additional attributes
386   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
387     return failure();
388 
389   return success();
390 }
391 
392 void ParallelOp::print(OpAsmPrinter &printer) {
393   // async()?
394   if (Value async = this->async())
395     printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ": "
396             << async.getType() << ")";
397 
398   // wait()?
399   printOperandList(waitOperands(), ParallelOp::getWaitKeyword(), printer);
400 
401   // num_gangs()?
402   if (Value numGangs = this->numGangs())
403     printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs
404             << ": " << numGangs.getType() << ")";
405 
406   // num_workers()?
407   if (Value numWorkers = this->numWorkers())
408     printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers
409             << ": " << numWorkers.getType() << ")";
410 
411   // vector_length()?
412   if (Value vectorLength = this->vectorLength())
413     printer << " " << ParallelOp::getVectorLengthKeyword() << "("
414             << vectorLength << ": " << vectorLength.getType() << ")";
415 
416   // if()?
417   if (Value ifCond = this->ifCond())
418     printer << " " << ParallelOp::getIfKeyword() << "(" << ifCond << ")";
419 
420   // self()?
421   if (Value selfCond = this->selfCond())
422     printer << " " << ParallelOp::getSelfKeyword() << "(" << selfCond << ")";
423 
424   // reduction()?
425   printOperandList(reductionOperands(), ParallelOp::getReductionKeyword(),
426                    printer);
427 
428   // copy()?
429   printOperandList(copyOperands(), ParallelOp::getCopyKeyword(), printer);
430 
431   // copyin()?
432   printOperandList(copyinOperands(), ParallelOp::getCopyinKeyword(), printer);
433 
434   // copyin_readonly()?
435   printOperandList(copyinReadonlyOperands(),
436                    ParallelOp::getCopyinReadonlyKeyword(), printer);
437 
438   // copyout()?
439   printOperandList(copyoutOperands(), ParallelOp::getCopyoutKeyword(), printer);
440 
441   // copyout_zero()?
442   printOperandList(copyoutZeroOperands(), ParallelOp::getCopyoutZeroKeyword(),
443                    printer);
444 
445   // create()?
446   printOperandList(createOperands(), ParallelOp::getCreateKeyword(), printer);
447 
448   // create_zero()?
449   printOperandList(createZeroOperands(), ParallelOp::getCreateZeroKeyword(),
450                    printer);
451 
452   // no_create()?
453   printOperandList(noCreateOperands(), ParallelOp::getNoCreateKeyword(),
454                    printer);
455 
456   // present()?
457   printOperandList(presentOperands(), ParallelOp::getPresentKeyword(), printer);
458 
459   // deviceptr()?
460   printOperandList(devicePtrOperands(), ParallelOp::getDevicePtrKeyword(),
461                    printer);
462 
463   // attach()?
464   printOperandList(attachOperands(), ParallelOp::getAttachKeyword(), printer);
465 
466   // private()?
467   printOperandList(gangPrivateOperands(), ParallelOp::getPrivateKeyword(),
468                    printer);
469 
470   // firstprivate()?
471   printOperandList(gangFirstPrivateOperands(),
472                    ParallelOp::getFirstPrivateKeyword(), printer);
473 
474   printer << ' ';
475   printer.printRegion(region(),
476                       /*printEntryBlockArgs=*/false,
477                       /*printBlockTerminators=*/true);
478   printer.printOptionalAttrDictWithKeyword(
479       (*this)->getAttrs(), ParallelOp::getOperandSegmentSizeAttr());
480 }
481 
482 unsigned ParallelOp::getNumDataOperands() {
483   return reductionOperands().size() + copyOperands().size() +
484          copyinOperands().size() + copyinReadonlyOperands().size() +
485          copyoutOperands().size() + copyoutZeroOperands().size() +
486          createOperands().size() + createZeroOperands().size() +
487          noCreateOperands().size() + presentOperands().size() +
488          devicePtrOperands().size() + attachOperands().size() +
489          gangPrivateOperands().size() + gangFirstPrivateOperands().size();
490 }
491 
492 Value ParallelOp::getDataOperand(unsigned i) {
493   unsigned numOptional = async() ? 1 : 0;
494   numOptional += numGangs() ? 1 : 0;
495   numOptional += numWorkers() ? 1 : 0;
496   numOptional += vectorLength() ? 1 : 0;
497   numOptional += ifCond() ? 1 : 0;
498   numOptional += selfCond() ? 1 : 0;
499   return getOperand(waitOperands().size() + numOptional + i);
500 }
501 
502 //===----------------------------------------------------------------------===//
503 // LoopOp
504 //===----------------------------------------------------------------------===//
505 
506 /// Parse acc.loop operation
507 /// operation := `acc.loop`
508 ///              (`gang` ( `(` (`num=` value)? (`,` `static=` value `)`)? )? )?
509 ///              (`vector` ( `(` value `)` )? )? (`worker` (`(` value `)`)? )?
510 ///              (`vector_length` `(` value `)`)?
511 ///              (`tile` `(` value-list `)`)?
512 ///              (`private` `(` value-list `)`)?
513 ///              (`reduction` `(` value-list `)`)?
514 ///              region attr-dict?
515 ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
516   Builder &builder = parser.getBuilder();
517   unsigned executionMapping = OpenACCExecMapping::NONE;
518   SmallVector<Type, 8> operandTypes;
519   SmallVector<OpAsmParser::UnresolvedOperand, 8> privateOperands,
520       reductionOperands;
521   SmallVector<OpAsmParser::UnresolvedOperand, 8> tileOperands;
522   OptionalParseResult gangNum, gangStatic, worker, vector;
523 
524   // gang?
525   if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangKeyword())))
526     executionMapping |= OpenACCExecMapping::GANG;
527 
528   // optional gang operand
529   if (succeeded(parser.parseOptionalLParen())) {
530     gangNum = parserOptionalOperandAndTypeWithPrefix(
531         parser, result, LoopOp::getGangNumKeyword());
532     if (gangNum.hasValue() && failed(*gangNum))
533       return failure();
534     parser.parseOptionalComma();
535     gangStatic = parserOptionalOperandAndTypeWithPrefix(
536         parser, result, LoopOp::getGangStaticKeyword());
537     if (gangStatic.hasValue() && failed(*gangStatic))
538       return failure();
539     parser.parseOptionalComma();
540     if (failed(parser.parseRParen()))
541       return failure();
542   }
543 
544   // worker?
545   if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerKeyword())))
546     executionMapping |= OpenACCExecMapping::WORKER;
547 
548   // optional worker operand
549   worker = parseOptionalOperandAndType(parser, result);
550   if (worker.hasValue() && failed(*worker))
551     return failure();
552 
553   // vector?
554   if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorKeyword())))
555     executionMapping |= OpenACCExecMapping::VECTOR;
556 
557   // optional vector operand
558   vector = parseOptionalOperandAndType(parser, result);
559   if (vector.hasValue() && failed(*vector))
560     return failure();
561 
562   // tile()?
563   if (failed(parseOperandList(parser, LoopOp::getTileKeyword(), tileOperands,
564                               operandTypes, result)))
565     return failure();
566 
567   // private()?
568   if (failed(parseOperandList(parser, LoopOp::getPrivateKeyword(),
569                               privateOperands, operandTypes, result)))
570     return failure();
571 
572   // reduction()?
573   if (failed(parseOperandList(parser, LoopOp::getReductionKeyword(),
574                               reductionOperands, operandTypes, result)))
575     return failure();
576 
577   if (executionMapping != acc::OpenACCExecMapping::NONE)
578     result.addAttribute(LoopOp::getExecutionMappingAttrName(),
579                         builder.getI64IntegerAttr(executionMapping));
580 
581   // Parse optional results in case there is a reduce.
582   if (parser.parseOptionalArrowTypeList(result.types))
583     return failure();
584 
585   if (failed(parseRegions<LoopOp>(parser, result)))
586     return failure();
587 
588   result.addAttribute(LoopOp::getOperandSegmentSizeAttr(),
589                       builder.getI32VectorAttr(
590                           {static_cast<int32_t>(gangNum.hasValue() ? 1 : 0),
591                            static_cast<int32_t>(gangStatic.hasValue() ? 1 : 0),
592                            static_cast<int32_t>(worker.hasValue() ? 1 : 0),
593                            static_cast<int32_t>(vector.hasValue() ? 1 : 0),
594                            static_cast<int32_t>(tileOperands.size()),
595                            static_cast<int32_t>(privateOperands.size()),
596                            static_cast<int32_t>(reductionOperands.size())}));
597 
598   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
599     return failure();
600 
601   return success();
602 }
603 
604 void LoopOp::print(OpAsmPrinter &printer) {
605   unsigned execMapping = exec_mapping();
606   if (execMapping & OpenACCExecMapping::GANG) {
607     printer << " " << LoopOp::getGangKeyword();
608     Value gangNum = this->gangNum();
609     Value gangStatic = this->gangStatic();
610 
611     // Print optional gang operands
612     if (gangNum || gangStatic) {
613       printer << "(";
614       if (gangNum) {
615         printer << LoopOp::getGangNumKeyword() << "=" << gangNum << ": "
616                 << gangNum.getType();
617         if (gangStatic)
618           printer << ", ";
619       }
620       if (gangStatic)
621         printer << LoopOp::getGangStaticKeyword() << "=" << gangStatic << ": "
622                 << gangStatic.getType();
623       printer << ")";
624     }
625   }
626 
627   if (execMapping & OpenACCExecMapping::WORKER) {
628     printer << " " << LoopOp::getWorkerKeyword();
629 
630     // Print optional worker operand if present
631     if (Value workerNum = this->workerNum())
632       printer << "(" << workerNum << ": " << workerNum.getType() << ")";
633   }
634 
635   if (execMapping & OpenACCExecMapping::VECTOR) {
636     printer << " " << LoopOp::getVectorKeyword();
637 
638     // Print optional vector operand if present
639     if (Value vectorLength = this->vectorLength())
640       printer << "(" << vectorLength << ": " << vectorLength.getType() << ")";
641   }
642 
643   // tile()?
644   printOperandList(tileOperands(), LoopOp::getTileKeyword(), printer);
645 
646   // private()?
647   printOperandList(privateOperands(), LoopOp::getPrivateKeyword(), printer);
648 
649   // reduction()?
650   printOperandList(reductionOperands(), LoopOp::getReductionKeyword(), printer);
651 
652   if (getNumResults() > 0)
653     printer << " -> (" << getResultTypes() << ")";
654 
655   printer << ' ';
656   printer.printRegion(region(),
657                       /*printEntryBlockArgs=*/false,
658                       /*printBlockTerminators=*/true);
659 
660   printer.printOptionalAttrDictWithKeyword(
661       (*this)->getAttrs(), {LoopOp::getExecutionMappingAttrName(),
662                             LoopOp::getOperandSegmentSizeAttr()});
663 }
664 
665 LogicalResult acc::LoopOp::verify() {
666   // auto, independent and seq attribute are mutually exclusive.
667   if ((auto_() && (independent() || seq())) || (independent() && seq())) {
668     return emitError("only one of " + acc::LoopOp::getAutoAttrName() + ", " +
669                      acc::LoopOp::getIndependentAttrName() + ", " +
670                      acc::LoopOp::getSeqAttrName() +
671                      " can be present at the same time");
672   }
673 
674   // Gang, worker and vector are incompatible with seq.
675   if (seq() && exec_mapping() != OpenACCExecMapping::NONE)
676     return emitError("gang, worker or vector cannot appear with the seq attr");
677 
678   // Check non-empty body().
679   if (region().empty())
680     return emitError("expected non-empty body.");
681 
682   return success();
683 }
684 
685 //===----------------------------------------------------------------------===//
686 // DataOp
687 //===----------------------------------------------------------------------===//
688 
689 LogicalResult acc::DataOp::verify() {
690   // 2.6.5. Data Construct restriction
691   // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
692   // attach, or default clause must appear on a data construct.
693   if (getOperands().empty() && !defaultAttr())
694     return emitError("at least one operand or the default attribute "
695                      "must appear on the data operation");
696   return success();
697 }
698 
699 unsigned DataOp::getNumDataOperands() {
700   return copyOperands().size() + copyinOperands().size() +
701          copyinReadonlyOperands().size() + copyoutOperands().size() +
702          copyoutZeroOperands().size() + createOperands().size() +
703          createZeroOperands().size() + noCreateOperands().size() +
704          presentOperands().size() + deviceptrOperands().size() +
705          attachOperands().size();
706 }
707 
708 Value DataOp::getDataOperand(unsigned i) {
709   unsigned numOptional = ifCond() ? 1 : 0;
710   return getOperand(numOptional + i);
711 }
712 
713 //===----------------------------------------------------------------------===//
714 // ExitDataOp
715 //===----------------------------------------------------------------------===//
716 
717 LogicalResult acc::ExitDataOp::verify() {
718   // 2.6.6. Data Exit Directive restriction
719   // At least one copyout, delete, or detach clause must appear on an exit data
720   // directive.
721   if (copyoutOperands().empty() && deleteOperands().empty() &&
722       detachOperands().empty())
723     return emitError(
724         "at least one operand in copyout, delete or detach must appear on the "
725         "exit data operation");
726 
727   // The async attribute represent the async clause without value. Therefore the
728   // attribute and operand cannot appear at the same time.
729   if (asyncOperand() && async())
730     return emitError("async attribute cannot appear with asyncOperand");
731 
732   // The wait attribute represent the wait clause without values. Therefore the
733   // attribute and operands cannot appear at the same time.
734   if (!waitOperands().empty() && wait())
735     return emitError("wait attribute cannot appear with waitOperands");
736 
737   if (waitDevnum() && waitOperands().empty())
738     return emitError("wait_devnum cannot appear without waitOperands");
739 
740   return success();
741 }
742 
743 unsigned ExitDataOp::getNumDataOperands() {
744   return copyoutOperands().size() + deleteOperands().size() +
745          detachOperands().size();
746 }
747 
748 Value ExitDataOp::getDataOperand(unsigned i) {
749   unsigned numOptional = ifCond() ? 1 : 0;
750   numOptional += asyncOperand() ? 1 : 0;
751   numOptional += waitDevnum() ? 1 : 0;
752   return getOperand(waitOperands().size() + numOptional + i);
753 }
754 
755 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
756                                              MLIRContext *context) {
757   results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
758 }
759 
760 //===----------------------------------------------------------------------===//
761 // EnterDataOp
762 //===----------------------------------------------------------------------===//
763 
764 LogicalResult acc::EnterDataOp::verify() {
765   // 2.6.6. Data Enter Directive restriction
766   // At least one copyin, create, or attach clause must appear on an enter data
767   // directive.
768   if (copyinOperands().empty() && createOperands().empty() &&
769       createZeroOperands().empty() && attachOperands().empty())
770     return emitError(
771         "at least one operand in copyin, create, "
772         "create_zero or attach must appear on the enter data operation");
773 
774   // The async attribute represent the async clause without value. Therefore the
775   // attribute and operand cannot appear at the same time.
776   if (asyncOperand() && async())
777     return emitError("async attribute cannot appear with asyncOperand");
778 
779   // The wait attribute represent the wait clause without values. Therefore the
780   // attribute and operands cannot appear at the same time.
781   if (!waitOperands().empty() && wait())
782     return emitError("wait attribute cannot appear with waitOperands");
783 
784   if (waitDevnum() && waitOperands().empty())
785     return emitError("wait_devnum cannot appear without waitOperands");
786 
787   return success();
788 }
789 
790 unsigned EnterDataOp::getNumDataOperands() {
791   return copyinOperands().size() + createOperands().size() +
792          createZeroOperands().size() + attachOperands().size();
793 }
794 
795 Value EnterDataOp::getDataOperand(unsigned i) {
796   unsigned numOptional = ifCond() ? 1 : 0;
797   numOptional += asyncOperand() ? 1 : 0;
798   numOptional += waitDevnum() ? 1 : 0;
799   return getOperand(waitOperands().size() + numOptional + i);
800 }
801 
802 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
803                                               MLIRContext *context) {
804   results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
805 }
806 
807 //===----------------------------------------------------------------------===//
808 // InitOp
809 //===----------------------------------------------------------------------===//
810 
811 LogicalResult acc::InitOp::verify() {
812   Operation *currOp = *this;
813   while ((currOp = currOp->getParentOp()))
814     if (isComputeOperation(currOp))
815       return emitOpError("cannot be nested in a compute operation");
816   return success();
817 }
818 
819 //===----------------------------------------------------------------------===//
820 // ShutdownOp
821 //===----------------------------------------------------------------------===//
822 
823 LogicalResult acc::ShutdownOp::verify() {
824   Operation *currOp = *this;
825   while ((currOp = currOp->getParentOp()))
826     if (isComputeOperation(currOp))
827       return emitOpError("cannot be nested in a compute operation");
828   return success();
829 }
830 
831 //===----------------------------------------------------------------------===//
832 // UpdateOp
833 //===----------------------------------------------------------------------===//
834 
835 LogicalResult acc::UpdateOp::verify() {
836   // At least one of host or device should have a value.
837   if (hostOperands().empty() && deviceOperands().empty())
838     return emitError(
839         "at least one value must be present in hostOperands or deviceOperands");
840 
841   // The async attribute represent the async clause without value. Therefore the
842   // attribute and operand cannot appear at the same time.
843   if (asyncOperand() && async())
844     return emitError("async attribute cannot appear with asyncOperand");
845 
846   // The wait attribute represent the wait clause without values. Therefore the
847   // attribute and operands cannot appear at the same time.
848   if (!waitOperands().empty() && wait())
849     return emitError("wait attribute cannot appear with waitOperands");
850 
851   if (waitDevnum() && waitOperands().empty())
852     return emitError("wait_devnum cannot appear without waitOperands");
853 
854   return success();
855 }
856 
857 unsigned UpdateOp::getNumDataOperands() {
858   return hostOperands().size() + deviceOperands().size();
859 }
860 
861 Value UpdateOp::getDataOperand(unsigned i) {
862   unsigned numOptional = asyncOperand() ? 1 : 0;
863   numOptional += waitDevnum() ? 1 : 0;
864   numOptional += ifCond() ? 1 : 0;
865   return getOperand(waitOperands().size() + deviceTypeOperands().size() +
866                     numOptional + i);
867 }
868 
869 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
870                                            MLIRContext *context) {
871   results.add<RemoveConstantIfCondition<UpdateOp>>(context);
872 }
873 
874 //===----------------------------------------------------------------------===//
875 // WaitOp
876 //===----------------------------------------------------------------------===//
877 
878 LogicalResult acc::WaitOp::verify() {
879   // The async attribute represent the async clause without value. Therefore the
880   // attribute and operand cannot appear at the same time.
881   if (asyncOperand() && async())
882     return emitError("async attribute cannot appear with asyncOperand");
883 
884   if (waitDevnum() && waitOperands().empty())
885     return emitError("wait_devnum cannot appear without waitOperands");
886 
887   return success();
888 }
889 
890 #define GET_OP_CLASSES
891 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
892 
893 #define GET_ATTRDEF_CLASSES
894 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
895