1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // =============================================================================
8 
9 #include "mlir/Dialect/OpenACC/OpenACC.h"
10 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "llvm/ADT/TypeSwitch.h"
18 
19 using namespace mlir;
20 using namespace acc;
21 
22 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
23 
24 //===----------------------------------------------------------------------===//
25 // OpenACC operations
26 //===----------------------------------------------------------------------===//
27 
/// Register the OpenACC dialect's operations and attributes. The type lists
/// passed to addOperations/addAttributes are expanded from the TableGen
/// generated .inc files selected by the GET_OP_LIST / GET_ATTRDEF_LIST macros.
void OpenACCDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
      >();
}
38 
39 template <typename StructureOp>
40 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
41                                 unsigned nRegions = 1) {
42 
43   SmallVector<Region *, 2> regions;
44   for (unsigned i = 0; i < nRegions; ++i)
45     regions.push_back(state.addRegion());
46 
47   for (Region *region : regions) {
48     if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
49       return failure();
50   }
51 
52   return success();
53 }
54 
55 static ParseResult
56 parseOperandList(OpAsmParser &parser, StringRef keyword,
57                  SmallVectorImpl<OpAsmParser::OperandType> &args,
58                  SmallVectorImpl<Type> &argTypes, OperationState &result) {
59   if (failed(parser.parseOptionalKeyword(keyword)))
60     return success();
61 
62   if (failed(parser.parseLParen()))
63     return failure();
64 
65   // Exit early if the list is empty.
66   if (succeeded(parser.parseOptionalRParen()))
67     return success();
68 
69   do {
70     OpAsmParser::OperandType arg;
71     Type type;
72 
73     if (parser.parseRegionArgument(arg) || parser.parseColonType(type))
74       return failure();
75 
76     args.push_back(arg);
77     argTypes.push_back(type);
78   } while (succeeded(parser.parseOptionalComma()));
79 
80   if (failed(parser.parseRParen()))
81     return failure();
82 
83   return parser.resolveOperands(args, argTypes, parser.getCurrentLocation(),
84                                 result.operands);
85 }
86 
87 static void printOperandList(Operation::operand_range operands,
88                              StringRef listName, OpAsmPrinter &printer) {
89 
90   if (!operands.empty()) {
91     printer << " " << listName << "(";
92     llvm::interleaveComma(operands, printer, [&](Value op) {
93       printer << op << ": " << op.getType();
94     });
95     printer << ")";
96   }
97 }
98 
99 static ParseResult parseOptionalOperand(OpAsmParser &parser, StringRef keyword,
100                                         OpAsmParser::OperandType &operand,
101                                         Type type, bool &hasOptional,
102                                         OperationState &result) {
103   hasOptional = false;
104   if (succeeded(parser.parseOptionalKeyword(keyword))) {
105     hasOptional = true;
106     if (parser.parseLParen() || parser.parseOperand(operand) ||
107         parser.resolveOperand(operand, type, result.operands) ||
108         parser.parseRParen())
109       return failure();
110   }
111   return success();
112 }
113 
114 static ParseResult parseOperandAndType(OpAsmParser &parser,
115                                        OperationState &result) {
116   OpAsmParser::OperandType operand;
117   Type type;
118   if (parser.parseOperand(operand) || parser.parseColonType(type) ||
119       parser.resolveOperand(operand, type, result.operands))
120     return failure();
121   return success();
122 }
123 
124 /// Parse optional operand and its type wrapped in parenthesis prefixed with
125 /// a keyword.
126 /// Example:
127 ///   keyword `(` %vectorLength: i64 `)`
128 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser,
129                                                        StringRef keyword,
130                                                        OperationState &result) {
131   OpAsmParser::OperandType operand;
132   if (succeeded(parser.parseOptionalKeyword(keyword))) {
133     return failure(parser.parseLParen() ||
134                    parseOperandAndType(parser, result) || parser.parseRParen());
135   }
136   return llvm::None;
137 }
138 
139 /// Parse optional operand and its type wrapped in parenthesis.
140 /// Example:
141 ///   `(` %vectorLength: i64 `)`
142 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser,
143                                                        OperationState &result) {
144   if (succeeded(parser.parseOptionalLParen())) {
145     return failure(parseOperandAndType(parser, result) || parser.parseRParen());
146   }
147   return llvm::None;
148 }
149 
150 /// Parse optional operand with its type prefixed with prefixKeyword `=`.
151 /// Example:
152 ///   num=%gangNum: i32
153 static OptionalParseResult parserOptionalOperandAndTypeWithPrefix(
154     OpAsmParser &parser, OperationState &result, StringRef prefixKeyword) {
155   if (succeeded(parser.parseOptionalKeyword(prefixKeyword))) {
156     parser.parseEqual();
157     return parseOperandAndType(parser, result);
158   }
159   return llvm::None;
160 }
161 
162 static bool isComputeOperation(Operation *op) {
163   return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op);
164 }
165 
166 namespace {
167 /// Pattern to remove operation without region that have constant false `ifCond`
168 /// and remove the condition from the operation if the `ifCond` is a true
169 /// constant.
170 template <typename OpTy>
171 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
172   using OpRewritePattern<OpTy>::OpRewritePattern;
173 
174   LogicalResult matchAndRewrite(OpTy op,
175                                 PatternRewriter &rewriter) const override {
176     // Early return if there is no condition.
177     Value ifCond = op.ifCond();
178     if (!ifCond)
179       return success();
180 
181     IntegerAttr constAttr;
182     if (matchPattern(ifCond, m_Constant(&constAttr))) {
183       if (constAttr.getInt())
184         rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); });
185       else
186         rewriter.eraseOp(op);
187     }
188 
189     return success();
190   }
191 };
192 } // namespace
193 
194 //===----------------------------------------------------------------------===//
195 // ParallelOp
196 //===----------------------------------------------------------------------===//
197 
198 /// Parse acc.parallel operation
199 /// operation := `acc.parallel` `async` `(` index `)`?
200 ///                             `wait` `(` index-list `)`?
201 ///                             `num_gangs` `(` value `)`?
202 ///                             `num_workers` `(` value `)`?
203 ///                             `vector_length` `(` value `)`?
204 ///                             `if` `(` value `)`?
205 ///                             `self` `(` value `)`?
206 ///                             `reduction` `(` value-list `)`?
207 ///                             `copy` `(` value-list `)`?
208 ///                             `copyin` `(` value-list `)`?
209 ///                             `copyin_readonly` `(` value-list `)`?
210 ///                             `copyout` `(` value-list `)`?
211 ///                             `copyout_zero` `(` value-list `)`?
212 ///                             `create` `(` value-list `)`?
213 ///                             `create_zero` `(` value-list `)`?
214 ///                             `no_create` `(` value-list `)`?
215 ///                             `present` `(` value-list `)`?
216 ///                             `deviceptr` `(` value-list `)`?
217 ///                             `attach` `(` value-list `)`?
218 ///                             `private` `(` value-list `)`?
219 ///                             `firstprivate` `(` value-list `)`?
220 ///                             region attr-dict?
221 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
222   Builder &builder = parser.getBuilder();
223   SmallVector<OpAsmParser::OperandType, 8> privateOperands,
224       firstprivateOperands, copyOperands, copyinOperands,
225       copyinReadonlyOperands, copyoutOperands, copyoutZeroOperands,
226       createOperands, createZeroOperands, noCreateOperands, presentOperands,
227       devicePtrOperands, attachOperands, waitOperands, reductionOperands;
228   SmallVector<Type, 8> waitOperandTypes, reductionOperandTypes,
229       copyOperandTypes, copyinOperandTypes, copyinReadonlyOperandTypes,
230       copyoutOperandTypes, copyoutZeroOperandTypes, createOperandTypes,
231       createZeroOperandTypes, noCreateOperandTypes, presentOperandTypes,
232       deviceptrOperandTypes, attachOperandTypes, privateOperandTypes,
233       firstprivateOperandTypes;
234 
235   SmallVector<Type, 8> operandTypes;
236   OpAsmParser::OperandType ifCond, selfCond;
237   bool hasIfCond = false, hasSelfCond = false;
238   OptionalParseResult async, numGangs, numWorkers, vectorLength;
239   Type i1Type = builder.getI1Type();
240 
241   // async()?
242   async = parseOptionalOperandAndType(parser, ParallelOp::getAsyncKeyword(),
243                                       result);
244   if (async.hasValue() && failed(*async))
245     return failure();
246 
247   // wait()?
248   if (failed(parseOperandList(parser, ParallelOp::getWaitKeyword(),
249                               waitOperands, waitOperandTypes, result)))
250     return failure();
251 
252   // num_gangs(value)?
253   numGangs = parseOptionalOperandAndType(
254       parser, ParallelOp::getNumGangsKeyword(), result);
255   if (numGangs.hasValue() && failed(*numGangs))
256     return failure();
257 
258   // num_workers(value)?
259   numWorkers = parseOptionalOperandAndType(
260       parser, ParallelOp::getNumWorkersKeyword(), result);
261   if (numWorkers.hasValue() && failed(*numWorkers))
262     return failure();
263 
264   // vector_length(value)?
265   vectorLength = parseOptionalOperandAndType(
266       parser, ParallelOp::getVectorLengthKeyword(), result);
267   if (vectorLength.hasValue() && failed(*vectorLength))
268     return failure();
269 
270   // if()?
271   if (failed(parseOptionalOperand(parser, ParallelOp::getIfKeyword(), ifCond,
272                                   i1Type, hasIfCond, result)))
273     return failure();
274 
275   // self()?
276   if (failed(parseOptionalOperand(parser, ParallelOp::getSelfKeyword(),
277                                   selfCond, i1Type, hasSelfCond, result)))
278     return failure();
279 
280   // reduction()?
281   if (failed(parseOperandList(parser, ParallelOp::getReductionKeyword(),
282                               reductionOperands, reductionOperandTypes,
283                               result)))
284     return failure();
285 
286   // copy()?
287   if (failed(parseOperandList(parser, ParallelOp::getCopyKeyword(),
288                               copyOperands, copyOperandTypes, result)))
289     return failure();
290 
291   // copyin()?
292   if (failed(parseOperandList(parser, ParallelOp::getCopyinKeyword(),
293                               copyinOperands, copyinOperandTypes, result)))
294     return failure();
295 
296   // copyin_readonly()?
297   if (failed(parseOperandList(parser, ParallelOp::getCopyinReadonlyKeyword(),
298                               copyinReadonlyOperands,
299                               copyinReadonlyOperandTypes, result)))
300     return failure();
301 
302   // copyout()?
303   if (failed(parseOperandList(parser, ParallelOp::getCopyoutKeyword(),
304                               copyoutOperands, copyoutOperandTypes, result)))
305     return failure();
306 
307   // copyout_zero()?
308   if (failed(parseOperandList(parser, ParallelOp::getCopyoutZeroKeyword(),
309                               copyoutZeroOperands, copyoutZeroOperandTypes,
310                               result)))
311     return failure();
312 
313   // create()?
314   if (failed(parseOperandList(parser, ParallelOp::getCreateKeyword(),
315                               createOperands, createOperandTypes, result)))
316     return failure();
317 
318   // create_zero()?
319   if (failed(parseOperandList(parser, ParallelOp::getCreateZeroKeyword(),
320                               createZeroOperands, createZeroOperandTypes,
321                               result)))
322     return failure();
323 
324   // no_create()?
325   if (failed(parseOperandList(parser, ParallelOp::getNoCreateKeyword(),
326                               noCreateOperands, noCreateOperandTypes, result)))
327     return failure();
328 
329   // present()?
330   if (failed(parseOperandList(parser, ParallelOp::getPresentKeyword(),
331                               presentOperands, presentOperandTypes, result)))
332     return failure();
333 
334   // deviceptr()?
335   if (failed(parseOperandList(parser, ParallelOp::getDevicePtrKeyword(),
336                               devicePtrOperands, deviceptrOperandTypes,
337                               result)))
338     return failure();
339 
340   // attach()?
341   if (failed(parseOperandList(parser, ParallelOp::getAttachKeyword(),
342                               attachOperands, attachOperandTypes, result)))
343     return failure();
344 
345   // private()?
346   if (failed(parseOperandList(parser, ParallelOp::getPrivateKeyword(),
347                               privateOperands, privateOperandTypes, result)))
348     return failure();
349 
350   // firstprivate()?
351   if (failed(parseOperandList(parser, ParallelOp::getFirstPrivateKeyword(),
352                               firstprivateOperands, firstprivateOperandTypes,
353                               result)))
354     return failure();
355 
356   // Parallel op region
357   if (failed(parseRegions<ParallelOp>(parser, result)))
358     return failure();
359 
360   result.addAttribute(
361       ParallelOp::getOperandSegmentSizeAttr(),
362       builder.getI32VectorAttr(
363           {static_cast<int32_t>(async.hasValue() ? 1 : 0),
364            static_cast<int32_t>(waitOperands.size()),
365            static_cast<int32_t>(numGangs.hasValue() ? 1 : 0),
366            static_cast<int32_t>(numWorkers.hasValue() ? 1 : 0),
367            static_cast<int32_t>(vectorLength.hasValue() ? 1 : 0),
368            static_cast<int32_t>(hasIfCond ? 1 : 0),
369            static_cast<int32_t>(hasSelfCond ? 1 : 0),
370            static_cast<int32_t>(reductionOperands.size()),
371            static_cast<int32_t>(copyOperands.size()),
372            static_cast<int32_t>(copyinOperands.size()),
373            static_cast<int32_t>(copyinReadonlyOperands.size()),
374            static_cast<int32_t>(copyoutOperands.size()),
375            static_cast<int32_t>(copyoutZeroOperands.size()),
376            static_cast<int32_t>(createOperands.size()),
377            static_cast<int32_t>(createZeroOperands.size()),
378            static_cast<int32_t>(noCreateOperands.size()),
379            static_cast<int32_t>(presentOperands.size()),
380            static_cast<int32_t>(devicePtrOperands.size()),
381            static_cast<int32_t>(attachOperands.size()),
382            static_cast<int32_t>(privateOperands.size()),
383            static_cast<int32_t>(firstprivateOperands.size())}));
384 
385   // Additional attributes
386   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
387     return failure();
388 
389   return success();
390 }
391 
392 void ParallelOp::print(OpAsmPrinter &printer) {
393   // async()?
394   if (Value async = this->async())
395     printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ": "
396             << async.getType() << ")";
397 
398   // wait()?
399   printOperandList(waitOperands(), ParallelOp::getWaitKeyword(), printer);
400 
401   // num_gangs()?
402   if (Value numGangs = this->numGangs())
403     printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs
404             << ": " << numGangs.getType() << ")";
405 
406   // num_workers()?
407   if (Value numWorkers = this->numWorkers())
408     printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers
409             << ": " << numWorkers.getType() << ")";
410 
411   // vector_length()?
412   if (Value vectorLength = this->vectorLength())
413     printer << " " << ParallelOp::getVectorLengthKeyword() << "("
414             << vectorLength << ": " << vectorLength.getType() << ")";
415 
416   // if()?
417   if (Value ifCond = this->ifCond())
418     printer << " " << ParallelOp::getIfKeyword() << "(" << ifCond << ")";
419 
420   // self()?
421   if (Value selfCond = this->selfCond())
422     printer << " " << ParallelOp::getSelfKeyword() << "(" << selfCond << ")";
423 
424   // reduction()?
425   printOperandList(reductionOperands(), ParallelOp::getReductionKeyword(),
426                    printer);
427 
428   // copy()?
429   printOperandList(copyOperands(), ParallelOp::getCopyKeyword(), printer);
430 
431   // copyin()?
432   printOperandList(copyinOperands(), ParallelOp::getCopyinKeyword(), printer);
433 
434   // copyin_readonly()?
435   printOperandList(copyinReadonlyOperands(),
436                    ParallelOp::getCopyinReadonlyKeyword(), printer);
437 
438   // copyout()?
439   printOperandList(copyoutOperands(), ParallelOp::getCopyoutKeyword(), printer);
440 
441   // copyout_zero()?
442   printOperandList(copyoutZeroOperands(), ParallelOp::getCopyoutZeroKeyword(),
443                    printer);
444 
445   // create()?
446   printOperandList(createOperands(), ParallelOp::getCreateKeyword(), printer);
447 
448   // create_zero()?
449   printOperandList(createZeroOperands(), ParallelOp::getCreateZeroKeyword(),
450                    printer);
451 
452   // no_create()?
453   printOperandList(noCreateOperands(), ParallelOp::getNoCreateKeyword(),
454                    printer);
455 
456   // present()?
457   printOperandList(presentOperands(), ParallelOp::getPresentKeyword(), printer);
458 
459   // deviceptr()?
460   printOperandList(devicePtrOperands(), ParallelOp::getDevicePtrKeyword(),
461                    printer);
462 
463   // attach()?
464   printOperandList(attachOperands(), ParallelOp::getAttachKeyword(), printer);
465 
466   // private()?
467   printOperandList(gangPrivateOperands(), ParallelOp::getPrivateKeyword(),
468                    printer);
469 
470   // firstprivate()?
471   printOperandList(gangFirstPrivateOperands(),
472                    ParallelOp::getFirstPrivateKeyword(), printer);
473 
474   printer << ' ';
475   printer.printRegion(region(),
476                       /*printEntryBlockArgs=*/false,
477                       /*printBlockTerminators=*/true);
478   printer.printOptionalAttrDictWithKeyword(
479       (*this)->getAttrs(), ParallelOp::getOperandSegmentSizeAttr());
480 }
481 
482 unsigned ParallelOp::getNumDataOperands() {
483   return reductionOperands().size() + copyOperands().size() +
484          copyinOperands().size() + copyinReadonlyOperands().size() +
485          copyoutOperands().size() + copyoutZeroOperands().size() +
486          createOperands().size() + createZeroOperands().size() +
487          noCreateOperands().size() + presentOperands().size() +
488          devicePtrOperands().size() + attachOperands().size() +
489          gangPrivateOperands().size() + gangFirstPrivateOperands().size();
490 }
491 
492 Value ParallelOp::getDataOperand(unsigned i) {
493   unsigned numOptional = async() ? 1 : 0;
494   numOptional += numGangs() ? 1 : 0;
495   numOptional += numWorkers() ? 1 : 0;
496   numOptional += vectorLength() ? 1 : 0;
497   numOptional += ifCond() ? 1 : 0;
498   numOptional += selfCond() ? 1 : 0;
499   return getOperand(waitOperands().size() + numOptional + i);
500 }
501 
502 //===----------------------------------------------------------------------===//
503 // LoopOp
504 //===----------------------------------------------------------------------===//
505 
506 /// Parse acc.loop operation
507 /// operation := `acc.loop`
508 ///              (`gang` ( `(` (`num=` value)? (`,` `static=` value `)`)? )? )?
509 ///              (`vector` ( `(` value `)` )? )? (`worker` (`(` value `)`)? )?
510 ///              (`vector_length` `(` value `)`)?
511 ///              (`tile` `(` value-list `)`)?
512 ///              (`private` `(` value-list `)`)?
513 ///              (`reduction` `(` value-list `)`)?
514 ///              region attr-dict?
515 ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
516   Builder &builder = parser.getBuilder();
517   unsigned executionMapping = OpenACCExecMapping::NONE;
518   SmallVector<Type, 8> operandTypes;
519   SmallVector<OpAsmParser::OperandType, 8> privateOperands, reductionOperands;
520   SmallVector<OpAsmParser::OperandType, 8> tileOperands;
521   OptionalParseResult gangNum, gangStatic, worker, vector;
522 
523   // gang?
524   if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangKeyword())))
525     executionMapping |= OpenACCExecMapping::GANG;
526 
527   // optional gang operand
528   if (succeeded(parser.parseOptionalLParen())) {
529     gangNum = parserOptionalOperandAndTypeWithPrefix(
530         parser, result, LoopOp::getGangNumKeyword());
531     if (gangNum.hasValue() && failed(*gangNum))
532       return failure();
533     parser.parseOptionalComma();
534     gangStatic = parserOptionalOperandAndTypeWithPrefix(
535         parser, result, LoopOp::getGangStaticKeyword());
536     if (gangStatic.hasValue() && failed(*gangStatic))
537       return failure();
538     parser.parseOptionalComma();
539     if (failed(parser.parseRParen()))
540       return failure();
541   }
542 
543   // worker?
544   if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerKeyword())))
545     executionMapping |= OpenACCExecMapping::WORKER;
546 
547   // optional worker operand
548   worker = parseOptionalOperandAndType(parser, result);
549   if (worker.hasValue() && failed(*worker))
550     return failure();
551 
552   // vector?
553   if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorKeyword())))
554     executionMapping |= OpenACCExecMapping::VECTOR;
555 
556   // optional vector operand
557   vector = parseOptionalOperandAndType(parser, result);
558   if (vector.hasValue() && failed(*vector))
559     return failure();
560 
561   // tile()?
562   if (failed(parseOperandList(parser, LoopOp::getTileKeyword(), tileOperands,
563                               operandTypes, result)))
564     return failure();
565 
566   // private()?
567   if (failed(parseOperandList(parser, LoopOp::getPrivateKeyword(),
568                               privateOperands, operandTypes, result)))
569     return failure();
570 
571   // reduction()?
572   if (failed(parseOperandList(parser, LoopOp::getReductionKeyword(),
573                               reductionOperands, operandTypes, result)))
574     return failure();
575 
576   if (executionMapping != acc::OpenACCExecMapping::NONE)
577     result.addAttribute(LoopOp::getExecutionMappingAttrName(),
578                         builder.getI64IntegerAttr(executionMapping));
579 
580   // Parse optional results in case there is a reduce.
581   if (parser.parseOptionalArrowTypeList(result.types))
582     return failure();
583 
584   if (failed(parseRegions<LoopOp>(parser, result)))
585     return failure();
586 
587   result.addAttribute(LoopOp::getOperandSegmentSizeAttr(),
588                       builder.getI32VectorAttr(
589                           {static_cast<int32_t>(gangNum.hasValue() ? 1 : 0),
590                            static_cast<int32_t>(gangStatic.hasValue() ? 1 : 0),
591                            static_cast<int32_t>(worker.hasValue() ? 1 : 0),
592                            static_cast<int32_t>(vector.hasValue() ? 1 : 0),
593                            static_cast<int32_t>(tileOperands.size()),
594                            static_cast<int32_t>(privateOperands.size()),
595                            static_cast<int32_t>(reductionOperands.size())}));
596 
597   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
598     return failure();
599 
600   return success();
601 }
602 
603 void LoopOp::print(OpAsmPrinter &printer) {
604   unsigned execMapping = exec_mapping();
605   if (execMapping & OpenACCExecMapping::GANG) {
606     printer << " " << LoopOp::getGangKeyword();
607     Value gangNum = this->gangNum();
608     Value gangStatic = this->gangStatic();
609 
610     // Print optional gang operands
611     if (gangNum || gangStatic) {
612       printer << "(";
613       if (gangNum) {
614         printer << LoopOp::getGangNumKeyword() << "=" << gangNum << ": "
615                 << gangNum.getType();
616         if (gangStatic)
617           printer << ", ";
618       }
619       if (gangStatic)
620         printer << LoopOp::getGangStaticKeyword() << "=" << gangStatic << ": "
621                 << gangStatic.getType();
622       printer << ")";
623     }
624   }
625 
626   if (execMapping & OpenACCExecMapping::WORKER) {
627     printer << " " << LoopOp::getWorkerKeyword();
628 
629     // Print optional worker operand if present
630     if (Value workerNum = this->workerNum())
631       printer << "(" << workerNum << ": " << workerNum.getType() << ")";
632   }
633 
634   if (execMapping & OpenACCExecMapping::VECTOR) {
635     printer << " " << LoopOp::getVectorKeyword();
636 
637     // Print optional vector operand if present
638     if (Value vectorLength = this->vectorLength())
639       printer << "(" << vectorLength << ": " << vectorLength.getType() << ")";
640   }
641 
642   // tile()?
643   printOperandList(tileOperands(), LoopOp::getTileKeyword(), printer);
644 
645   // private()?
646   printOperandList(privateOperands(), LoopOp::getPrivateKeyword(), printer);
647 
648   // reduction()?
649   printOperandList(reductionOperands(), LoopOp::getReductionKeyword(), printer);
650 
651   if (getNumResults() > 0)
652     printer << " -> (" << getResultTypes() << ")";
653 
654   printer << ' ';
655   printer.printRegion(region(),
656                       /*printEntryBlockArgs=*/false,
657                       /*printBlockTerminators=*/true);
658 
659   printer.printOptionalAttrDictWithKeyword(
660       (*this)->getAttrs(), {LoopOp::getExecutionMappingAttrName(),
661                             LoopOp::getOperandSegmentSizeAttr()});
662 }
663 
664 LogicalResult acc::LoopOp::verify() {
665   // auto, independent and seq attribute are mutually exclusive.
666   if ((auto_() && (independent() || seq())) || (independent() && seq())) {
667     return emitError("only one of " + acc::LoopOp::getAutoAttrName() + ", " +
668                      acc::LoopOp::getIndependentAttrName() + ", " +
669                      acc::LoopOp::getSeqAttrName() +
670                      " can be present at the same time");
671   }
672 
673   // Gang, worker and vector are incompatible with seq.
674   if (seq() && exec_mapping() != OpenACCExecMapping::NONE)
675     return emitError("gang, worker or vector cannot appear with the seq attr");
676 
677   // Check non-empty body().
678   if (region().empty())
679     return emitError("expected non-empty body.");
680 
681   return success();
682 }
683 
684 //===----------------------------------------------------------------------===//
685 // DataOp
686 //===----------------------------------------------------------------------===//
687 
688 LogicalResult acc::DataOp::verify() {
689   // 2.6.5. Data Construct restriction
690   // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
691   // attach, or default clause must appear on a data construct.
692   if (getOperands().empty() && !defaultAttr())
693     return emitError("at least one operand or the default attribute "
694                      "must appear on the data operation");
695   return success();
696 }
697 
698 unsigned DataOp::getNumDataOperands() {
699   return copyOperands().size() + copyinOperands().size() +
700          copyinReadonlyOperands().size() + copyoutOperands().size() +
701          copyoutZeroOperands().size() + createOperands().size() +
702          createZeroOperands().size() + noCreateOperands().size() +
703          presentOperands().size() + deviceptrOperands().size() +
704          attachOperands().size();
705 }
706 
707 Value DataOp::getDataOperand(unsigned i) {
708   unsigned numOptional = ifCond() ? 1 : 0;
709   return getOperand(numOptional + i);
710 }
711 
712 //===----------------------------------------------------------------------===//
713 // ExitDataOp
714 //===----------------------------------------------------------------------===//
715 
716 LogicalResult acc::ExitDataOp::verify() {
717   // 2.6.6. Data Exit Directive restriction
718   // At least one copyout, delete, or detach clause must appear on an exit data
719   // directive.
720   if (copyoutOperands().empty() && deleteOperands().empty() &&
721       detachOperands().empty())
722     return emitError(
723         "at least one operand in copyout, delete or detach must appear on the "
724         "exit data operation");
725 
726   // The async attribute represent the async clause without value. Therefore the
727   // attribute and operand cannot appear at the same time.
728   if (asyncOperand() && async())
729     return emitError("async attribute cannot appear with asyncOperand");
730 
731   // The wait attribute represent the wait clause without values. Therefore the
732   // attribute and operands cannot appear at the same time.
733   if (!waitOperands().empty() && wait())
734     return emitError("wait attribute cannot appear with waitOperands");
735 
736   if (waitDevnum() && waitOperands().empty())
737     return emitError("wait_devnum cannot appear without waitOperands");
738 
739   return success();
740 }
741 
742 unsigned ExitDataOp::getNumDataOperands() {
743   return copyoutOperands().size() + deleteOperands().size() +
744          detachOperands().size();
745 }
746 
747 Value ExitDataOp::getDataOperand(unsigned i) {
748   unsigned numOptional = ifCond() ? 1 : 0;
749   numOptional += asyncOperand() ? 1 : 0;
750   numOptional += waitDevnum() ? 1 : 0;
751   return getOperand(waitOperands().size() + numOptional + i);
752 }
753 
754 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
755                                              MLIRContext *context) {
756   results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
757 }
758 
759 //===----------------------------------------------------------------------===//
760 // EnterDataOp
761 //===----------------------------------------------------------------------===//
762 
763 LogicalResult acc::EnterDataOp::verify() {
764   // 2.6.6. Data Enter Directive restriction
765   // At least one copyin, create, or attach clause must appear on an enter data
766   // directive.
767   if (copyinOperands().empty() && createOperands().empty() &&
768       createZeroOperands().empty() && attachOperands().empty())
769     return emitError(
770         "at least one operand in copyin, create, "
771         "create_zero or attach must appear on the enter data operation");
772 
773   // The async attribute represent the async clause without value. Therefore the
774   // attribute and operand cannot appear at the same time.
775   if (asyncOperand() && async())
776     return emitError("async attribute cannot appear with asyncOperand");
777 
778   // The wait attribute represent the wait clause without values. Therefore the
779   // attribute and operands cannot appear at the same time.
780   if (!waitOperands().empty() && wait())
781     return emitError("wait attribute cannot appear with waitOperands");
782 
783   if (waitDevnum() && waitOperands().empty())
784     return emitError("wait_devnum cannot appear without waitOperands");
785 
786   return success();
787 }
788 
789 unsigned EnterDataOp::getNumDataOperands() {
790   return copyinOperands().size() + createOperands().size() +
791          createZeroOperands().size() + attachOperands().size();
792 }
793 
794 Value EnterDataOp::getDataOperand(unsigned i) {
795   unsigned numOptional = ifCond() ? 1 : 0;
796   numOptional += asyncOperand() ? 1 : 0;
797   numOptional += waitDevnum() ? 1 : 0;
798   return getOperand(waitOperands().size() + numOptional + i);
799 }
800 
801 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
802                                               MLIRContext *context) {
803   results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
804 }
805 
806 //===----------------------------------------------------------------------===//
807 // InitOp
808 //===----------------------------------------------------------------------===//
809 
810 LogicalResult acc::InitOp::verify() {
811   Operation *currOp = *this;
812   while ((currOp = currOp->getParentOp()))
813     if (isComputeOperation(currOp))
814       return emitOpError("cannot be nested in a compute operation");
815   return success();
816 }
817 
818 //===----------------------------------------------------------------------===//
819 // ShutdownOp
820 //===----------------------------------------------------------------------===//
821 
822 LogicalResult acc::ShutdownOp::verify() {
823   Operation *currOp = *this;
824   while ((currOp = currOp->getParentOp()))
825     if (isComputeOperation(currOp))
826       return emitOpError("cannot be nested in a compute operation");
827   return success();
828 }
829 
830 //===----------------------------------------------------------------------===//
831 // UpdateOp
832 //===----------------------------------------------------------------------===//
833 
834 LogicalResult acc::UpdateOp::verify() {
835   // At least one of host or device should have a value.
836   if (hostOperands().empty() && deviceOperands().empty())
837     return emitError(
838         "at least one value must be present in hostOperands or deviceOperands");
839 
840   // The async attribute represent the async clause without value. Therefore the
841   // attribute and operand cannot appear at the same time.
842   if (asyncOperand() && async())
843     return emitError("async attribute cannot appear with asyncOperand");
844 
845   // The wait attribute represent the wait clause without values. Therefore the
846   // attribute and operands cannot appear at the same time.
847   if (!waitOperands().empty() && wait())
848     return emitError("wait attribute cannot appear with waitOperands");
849 
850   if (waitDevnum() && waitOperands().empty())
851     return emitError("wait_devnum cannot appear without waitOperands");
852 
853   return success();
854 }
855 
856 unsigned UpdateOp::getNumDataOperands() {
857   return hostOperands().size() + deviceOperands().size();
858 }
859 
860 Value UpdateOp::getDataOperand(unsigned i) {
861   unsigned numOptional = asyncOperand() ? 1 : 0;
862   numOptional += waitDevnum() ? 1 : 0;
863   numOptional += ifCond() ? 1 : 0;
864   return getOperand(waitOperands().size() + deviceTypeOperands().size() +
865                     numOptional + i);
866 }
867 
868 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
869                                            MLIRContext *context) {
870   results.add<RemoveConstantIfCondition<UpdateOp>>(context);
871 }
872 
873 //===----------------------------------------------------------------------===//
874 // WaitOp
875 //===----------------------------------------------------------------------===//
876 
877 LogicalResult acc::WaitOp::verify() {
878   // The async attribute represent the async clause without value. Therefore the
879   // attribute and operand cannot appear at the same time.
880   if (asyncOperand() && async())
881     return emitError("async attribute cannot appear with asyncOperand");
882 
883   if (waitDevnum() && waitOperands().empty())
884     return emitError("wait_devnum cannot appear without waitOperands");
885 
886   return success();
887 }
888 
889 #define GET_OP_CLASSES
890 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
891 
892 #define GET_ATTRDEF_CLASSES
893 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
894