1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // =============================================================================
8 
9 #include "mlir/Dialect/OpenACC/OpenACC.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
12 #include "mlir/Dialect/StandardOps/IR/Ops.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/DialectImplementation.h"
16 #include "mlir/IR/OpImplementation.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 
20 using namespace mlir;
21 using namespace acc;
22 
23 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
24 
25 //===----------------------------------------------------------------------===//
26 // OpenACC operations
27 //===----------------------------------------------------------------------===//
28 
/// Register the OpenACC dialect's operations and attributes with the dialect.
/// The concrete type lists are not written here: they are spliced in from the
/// generated `.cpp.inc` files, whose GET_OP_LIST / GET_ATTRDEF_LIST sections
/// expand to comma-separated class names used as template arguments.
void OpenACCDialect::initialize() {
  // Expands to `addOperations<OpA, OpB, ...>()`.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
      >();
  // Expands to `addAttributes<AttrA, AttrB, ...>()`.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
      >();
}
39 
40 template <typename StructureOp>
41 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
42                                 unsigned nRegions = 1) {
43 
44   SmallVector<Region *, 2> regions;
45   for (unsigned i = 0; i < nRegions; ++i)
46     regions.push_back(state.addRegion());
47 
48   for (Region *region : regions) {
49     if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
50       return failure();
51   }
52 
53   return success();
54 }
55 
56 static ParseResult
57 parseOperandList(OpAsmParser &parser, StringRef keyword,
58                  SmallVectorImpl<OpAsmParser::OperandType> &args,
59                  SmallVectorImpl<Type> &argTypes, OperationState &result) {
60   if (failed(parser.parseOptionalKeyword(keyword)))
61     return success();
62 
63   if (failed(parser.parseLParen()))
64     return failure();
65 
66   // Exit early if the list is empty.
67   if (succeeded(parser.parseOptionalRParen()))
68     return success();
69 
70   do {
71     OpAsmParser::OperandType arg;
72     Type type;
73 
74     if (parser.parseRegionArgument(arg) || parser.parseColonType(type))
75       return failure();
76 
77     args.push_back(arg);
78     argTypes.push_back(type);
79   } while (succeeded(parser.parseOptionalComma()));
80 
81   if (failed(parser.parseRParen()))
82     return failure();
83 
84   return parser.resolveOperands(args, argTypes, parser.getCurrentLocation(),
85                                 result.operands);
86 }
87 
88 static void printOperandList(Operation::operand_range operands,
89                              StringRef listName, OpAsmPrinter &printer) {
90 
91   if (!operands.empty()) {
92     printer << " " << listName << "(";
93     llvm::interleaveComma(operands, printer, [&](Value op) {
94       printer << op << ": " << op.getType();
95     });
96     printer << ")";
97   }
98 }
99 
100 static ParseResult parseOptionalOperand(OpAsmParser &parser, StringRef keyword,
101                                         OpAsmParser::OperandType &operand,
102                                         Type type, bool &hasOptional,
103                                         OperationState &result) {
104   hasOptional = false;
105   if (succeeded(parser.parseOptionalKeyword(keyword))) {
106     hasOptional = true;
107     if (parser.parseLParen() || parser.parseOperand(operand) ||
108         parser.resolveOperand(operand, type, result.operands) ||
109         parser.parseRParen())
110       return failure();
111   }
112   return success();
113 }
114 
115 static ParseResult parseOperandAndType(OpAsmParser &parser,
116                                        OperationState &result) {
117   OpAsmParser::OperandType operand;
118   Type type;
119   if (parser.parseOperand(operand) || parser.parseColonType(type) ||
120       parser.resolveOperand(operand, type, result.operands))
121     return failure();
122   return success();
123 }
124 
125 /// Parse optional operand and its type wrapped in parenthesis prefixed with
126 /// a keyword.
127 /// Example:
128 ///   keyword `(` %vectorLength: i64 `)`
129 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser,
130                                                        StringRef keyword,
131                                                        OperationState &result) {
132   OpAsmParser::OperandType operand;
133   if (succeeded(parser.parseOptionalKeyword(keyword))) {
134     return failure(parser.parseLParen() ||
135                    parseOperandAndType(parser, result) || parser.parseRParen());
136   }
137   return llvm::None;
138 }
139 
140 /// Parse optional operand and its type wrapped in parenthesis.
141 /// Example:
142 ///   `(` %vectorLength: i64 `)`
143 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser,
144                                                        OperationState &result) {
145   if (succeeded(parser.parseOptionalLParen())) {
146     return failure(parseOperandAndType(parser, result) || parser.parseRParen());
147   }
148   return llvm::None;
149 }
150 
151 /// Parse optional operand with its type prefixed with prefixKeyword `=`.
152 /// Example:
153 ///   num=%gangNum: i32
154 static OptionalParseResult parserOptionalOperandAndTypeWithPrefix(
155     OpAsmParser &parser, OperationState &result, StringRef prefixKeyword) {
156   if (succeeded(parser.parseOptionalKeyword(prefixKeyword))) {
157     parser.parseEqual();
158     return parseOperandAndType(parser, result);
159   }
160   return llvm::None;
161 }
162 
163 static bool isComputeOperation(Operation *op) {
164   return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op);
165 }
166 
167 namespace {
168 /// Pattern to remove operation without region that have constant false `ifCond`
169 /// and remove the condition from the operation if the `ifCond` is a true
170 /// constant.
171 template <typename OpTy>
172 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
173   using OpRewritePattern<OpTy>::OpRewritePattern;
174 
175   LogicalResult matchAndRewrite(OpTy op,
176                                 PatternRewriter &rewriter) const override {
177     // Early return if there is no condition.
178     if (!op.ifCond())
179       return success();
180 
181     auto constOp = op.ifCond().template getDefiningOp<arith::ConstantOp>();
182     if (constOp && constOp.getValue().template cast<IntegerAttr>().getInt())
183       rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); });
184     else if (constOp)
185       rewriter.eraseOp(op);
186 
187     return success();
188   }
189 };
190 } // namespace
191 
192 //===----------------------------------------------------------------------===//
193 // ParallelOp
194 //===----------------------------------------------------------------------===//
195 
196 /// Parse acc.parallel operation
197 /// operation := `acc.parallel` `async` `(` index `)`?
198 ///                             `wait` `(` index-list `)`?
199 ///                             `num_gangs` `(` value `)`?
200 ///                             `num_workers` `(` value `)`?
201 ///                             `vector_length` `(` value `)`?
202 ///                             `if` `(` value `)`?
203 ///                             `self` `(` value `)`?
204 ///                             `reduction` `(` value-list `)`?
205 ///                             `copy` `(` value-list `)`?
206 ///                             `copyin` `(` value-list `)`?
207 ///                             `copyin_readonly` `(` value-list `)`?
208 ///                             `copyout` `(` value-list `)`?
209 ///                             `copyout_zero` `(` value-list `)`?
210 ///                             `create` `(` value-list `)`?
211 ///                             `create_zero` `(` value-list `)`?
212 ///                             `no_create` `(` value-list `)`?
213 ///                             `present` `(` value-list `)`?
214 ///                             `deviceptr` `(` value-list `)`?
215 ///                             `attach` `(` value-list `)`?
216 ///                             `private` `(` value-list `)`?
217 ///                             `firstprivate` `(` value-list `)`?
218 ///                             region attr-dict?
219 static ParseResult parseParallelOp(OpAsmParser &parser,
220                                    OperationState &result) {
221   Builder &builder = parser.getBuilder();
222   SmallVector<OpAsmParser::OperandType, 8> privateOperands,
223       firstprivateOperands, copyOperands, copyinOperands,
224       copyinReadonlyOperands, copyoutOperands, copyoutZeroOperands,
225       createOperands, createZeroOperands, noCreateOperands, presentOperands,
226       devicePtrOperands, attachOperands, waitOperands, reductionOperands;
227   SmallVector<Type, 8> waitOperandTypes, reductionOperandTypes,
228       copyOperandTypes, copyinOperandTypes, copyinReadonlyOperandTypes,
229       copyoutOperandTypes, copyoutZeroOperandTypes, createOperandTypes,
230       createZeroOperandTypes, noCreateOperandTypes, presentOperandTypes,
231       deviceptrOperandTypes, attachOperandTypes, privateOperandTypes,
232       firstprivateOperandTypes;
233 
234   SmallVector<Type, 8> operandTypes;
235   OpAsmParser::OperandType ifCond, selfCond;
236   bool hasIfCond = false, hasSelfCond = false;
237   OptionalParseResult async, numGangs, numWorkers, vectorLength;
238   Type i1Type = builder.getI1Type();
239 
240   // async()?
241   async = parseOptionalOperandAndType(parser, ParallelOp::getAsyncKeyword(),
242                                       result);
243   if (async.hasValue() && failed(*async))
244     return failure();
245 
246   // wait()?
247   if (failed(parseOperandList(parser, ParallelOp::getWaitKeyword(),
248                               waitOperands, waitOperandTypes, result)))
249     return failure();
250 
251   // num_gangs(value)?
252   numGangs = parseOptionalOperandAndType(
253       parser, ParallelOp::getNumGangsKeyword(), result);
254   if (numGangs.hasValue() && failed(*numGangs))
255     return failure();
256 
257   // num_workers(value)?
258   numWorkers = parseOptionalOperandAndType(
259       parser, ParallelOp::getNumWorkersKeyword(), result);
260   if (numWorkers.hasValue() && failed(*numWorkers))
261     return failure();
262 
263   // vector_length(value)?
264   vectorLength = parseOptionalOperandAndType(
265       parser, ParallelOp::getVectorLengthKeyword(), result);
266   if (vectorLength.hasValue() && failed(*vectorLength))
267     return failure();
268 
269   // if()?
270   if (failed(parseOptionalOperand(parser, ParallelOp::getIfKeyword(), ifCond,
271                                   i1Type, hasIfCond, result)))
272     return failure();
273 
274   // self()?
275   if (failed(parseOptionalOperand(parser, ParallelOp::getSelfKeyword(),
276                                   selfCond, i1Type, hasSelfCond, result)))
277     return failure();
278 
279   // reduction()?
280   if (failed(parseOperandList(parser, ParallelOp::getReductionKeyword(),
281                               reductionOperands, reductionOperandTypes,
282                               result)))
283     return failure();
284 
285   // copy()?
286   if (failed(parseOperandList(parser, ParallelOp::getCopyKeyword(),
287                               copyOperands, copyOperandTypes, result)))
288     return failure();
289 
290   // copyin()?
291   if (failed(parseOperandList(parser, ParallelOp::getCopyinKeyword(),
292                               copyinOperands, copyinOperandTypes, result)))
293     return failure();
294 
295   // copyin_readonly()?
296   if (failed(parseOperandList(parser, ParallelOp::getCopyinReadonlyKeyword(),
297                               copyinReadonlyOperands,
298                               copyinReadonlyOperandTypes, result)))
299     return failure();
300 
301   // copyout()?
302   if (failed(parseOperandList(parser, ParallelOp::getCopyoutKeyword(),
303                               copyoutOperands, copyoutOperandTypes, result)))
304     return failure();
305 
306   // copyout_zero()?
307   if (failed(parseOperandList(parser, ParallelOp::getCopyoutZeroKeyword(),
308                               copyoutZeroOperands, copyoutZeroOperandTypes,
309                               result)))
310     return failure();
311 
312   // create()?
313   if (failed(parseOperandList(parser, ParallelOp::getCreateKeyword(),
314                               createOperands, createOperandTypes, result)))
315     return failure();
316 
317   // create_zero()?
318   if (failed(parseOperandList(parser, ParallelOp::getCreateZeroKeyword(),
319                               createZeroOperands, createZeroOperandTypes,
320                               result)))
321     return failure();
322 
323   // no_create()?
324   if (failed(parseOperandList(parser, ParallelOp::getNoCreateKeyword(),
325                               noCreateOperands, noCreateOperandTypes, result)))
326     return failure();
327 
328   // present()?
329   if (failed(parseOperandList(parser, ParallelOp::getPresentKeyword(),
330                               presentOperands, presentOperandTypes, result)))
331     return failure();
332 
333   // deviceptr()?
334   if (failed(parseOperandList(parser, ParallelOp::getDevicePtrKeyword(),
335                               devicePtrOperands, deviceptrOperandTypes,
336                               result)))
337     return failure();
338 
339   // attach()?
340   if (failed(parseOperandList(parser, ParallelOp::getAttachKeyword(),
341                               attachOperands, attachOperandTypes, result)))
342     return failure();
343 
344   // private()?
345   if (failed(parseOperandList(parser, ParallelOp::getPrivateKeyword(),
346                               privateOperands, privateOperandTypes, result)))
347     return failure();
348 
349   // firstprivate()?
350   if (failed(parseOperandList(parser, ParallelOp::getFirstPrivateKeyword(),
351                               firstprivateOperands, firstprivateOperandTypes,
352                               result)))
353     return failure();
354 
355   // Parallel op region
356   if (failed(parseRegions<ParallelOp>(parser, result)))
357     return failure();
358 
359   result.addAttribute(
360       ParallelOp::getOperandSegmentSizeAttr(),
361       builder.getI32VectorAttr(
362           {static_cast<int32_t>(async.hasValue() ? 1 : 0),
363            static_cast<int32_t>(waitOperands.size()),
364            static_cast<int32_t>(numGangs.hasValue() ? 1 : 0),
365            static_cast<int32_t>(numWorkers.hasValue() ? 1 : 0),
366            static_cast<int32_t>(vectorLength.hasValue() ? 1 : 0),
367            static_cast<int32_t>(hasIfCond ? 1 : 0),
368            static_cast<int32_t>(hasSelfCond ? 1 : 0),
369            static_cast<int32_t>(reductionOperands.size()),
370            static_cast<int32_t>(copyOperands.size()),
371            static_cast<int32_t>(copyinOperands.size()),
372            static_cast<int32_t>(copyinReadonlyOperands.size()),
373            static_cast<int32_t>(copyoutOperands.size()),
374            static_cast<int32_t>(copyoutZeroOperands.size()),
375            static_cast<int32_t>(createOperands.size()),
376            static_cast<int32_t>(createZeroOperands.size()),
377            static_cast<int32_t>(noCreateOperands.size()),
378            static_cast<int32_t>(presentOperands.size()),
379            static_cast<int32_t>(devicePtrOperands.size()),
380            static_cast<int32_t>(attachOperands.size()),
381            static_cast<int32_t>(privateOperands.size()),
382            static_cast<int32_t>(firstprivateOperands.size())}));
383 
384   // Additional attributes
385   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
386     return failure();
387 
388   return success();
389 }
390 
391 static void print(OpAsmPrinter &printer, ParallelOp &op) {
392   // async()?
393   if (Value async = op.async())
394     printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ": "
395             << async.getType() << ")";
396 
397   // wait()?
398   printOperandList(op.waitOperands(), ParallelOp::getWaitKeyword(), printer);
399 
400   // num_gangs()?
401   if (Value numGangs = op.numGangs())
402     printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs
403             << ": " << numGangs.getType() << ")";
404 
405   // num_workers()?
406   if (Value numWorkers = op.numWorkers())
407     printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers
408             << ": " << numWorkers.getType() << ")";
409 
410   // vector_length()?
411   if (Value vectorLength = op.vectorLength())
412     printer << " " << ParallelOp::getVectorLengthKeyword() << "("
413             << vectorLength << ": " << vectorLength.getType() << ")";
414 
415   // if()?
416   if (Value ifCond = op.ifCond())
417     printer << " " << ParallelOp::getIfKeyword() << "(" << ifCond << ")";
418 
419   // self()?
420   if (Value selfCond = op.selfCond())
421     printer << " " << ParallelOp::getSelfKeyword() << "(" << selfCond << ")";
422 
423   // reduction()?
424   printOperandList(op.reductionOperands(), ParallelOp::getReductionKeyword(),
425                    printer);
426 
427   // copy()?
428   printOperandList(op.copyOperands(), ParallelOp::getCopyKeyword(), printer);
429 
430   // copyin()?
431   printOperandList(op.copyinOperands(), ParallelOp::getCopyinKeyword(),
432                    printer);
433 
434   // copyin_readonly()?
435   printOperandList(op.copyinReadonlyOperands(),
436                    ParallelOp::getCopyinReadonlyKeyword(), printer);
437 
438   // copyout()?
439   printOperandList(op.copyoutOperands(), ParallelOp::getCopyoutKeyword(),
440                    printer);
441 
442   // copyout_zero()?
443   printOperandList(op.copyoutZeroOperands(),
444                    ParallelOp::getCopyoutZeroKeyword(), printer);
445 
446   // create()?
447   printOperandList(op.createOperands(), ParallelOp::getCreateKeyword(),
448                    printer);
449 
450   // create_zero()?
451   printOperandList(op.createZeroOperands(), ParallelOp::getCreateZeroKeyword(),
452                    printer);
453 
454   // no_create()?
455   printOperandList(op.noCreateOperands(), ParallelOp::getNoCreateKeyword(),
456                    printer);
457 
458   // present()?
459   printOperandList(op.presentOperands(), ParallelOp::getPresentKeyword(),
460                    printer);
461 
462   // deviceptr()?
463   printOperandList(op.devicePtrOperands(), ParallelOp::getDevicePtrKeyword(),
464                    printer);
465 
466   // attach()?
467   printOperandList(op.attachOperands(), ParallelOp::getAttachKeyword(),
468                    printer);
469 
470   // private()?
471   printOperandList(op.gangPrivateOperands(), ParallelOp::getPrivateKeyword(),
472                    printer);
473 
474   // firstprivate()?
475   printOperandList(op.gangFirstPrivateOperands(),
476                    ParallelOp::getFirstPrivateKeyword(), printer);
477 
478   printer << ' ';
479   printer.printRegion(op.region(),
480                       /*printEntryBlockArgs=*/false,
481                       /*printBlockTerminators=*/true);
482   printer.printOptionalAttrDictWithKeyword(
483       op->getAttrs(), ParallelOp::getOperandSegmentSizeAttr());
484 }
485 
486 unsigned ParallelOp::getNumDataOperands() {
487   return reductionOperands().size() + copyOperands().size() +
488          copyinOperands().size() + copyinReadonlyOperands().size() +
489          copyoutOperands().size() + copyoutZeroOperands().size() +
490          createOperands().size() + createZeroOperands().size() +
491          noCreateOperands().size() + presentOperands().size() +
492          devicePtrOperands().size() + attachOperands().size() +
493          gangPrivateOperands().size() + gangFirstPrivateOperands().size();
494 }
495 
496 Value ParallelOp::getDataOperand(unsigned i) {
497   unsigned numOptional = async() ? 1 : 0;
498   numOptional += numGangs() ? 1 : 0;
499   numOptional += numWorkers() ? 1 : 0;
500   numOptional += vectorLength() ? 1 : 0;
501   numOptional += ifCond() ? 1 : 0;
502   numOptional += selfCond() ? 1 : 0;
503   return getOperand(waitOperands().size() + numOptional + i);
504 }
505 
506 //===----------------------------------------------------------------------===//
507 // LoopOp
508 //===----------------------------------------------------------------------===//
509 
510 /// Parse acc.loop operation
511 /// operation := `acc.loop`
512 ///              (`gang` ( `(` (`num=` value)? (`,` `static=` value `)`)? )? )?
513 ///              (`vector` ( `(` value `)` )? )? (`worker` (`(` value `)`)? )?
514 ///              (`vector_length` `(` value `)`)?
515 ///              (`tile` `(` value-list `)`)?
516 ///              (`private` `(` value-list `)`)?
517 ///              (`reduction` `(` value-list `)`)?
518 ///              region attr-dict?
519 static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
520   Builder &builder = parser.getBuilder();
521   unsigned executionMapping = OpenACCExecMapping::NONE;
522   SmallVector<Type, 8> operandTypes;
523   SmallVector<OpAsmParser::OperandType, 8> privateOperands, reductionOperands;
524   SmallVector<OpAsmParser::OperandType, 8> tileOperands;
525   OptionalParseResult gangNum, gangStatic, worker, vector;
526 
527   // gang?
528   if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangKeyword())))
529     executionMapping |= OpenACCExecMapping::GANG;
530 
531   // optional gang operand
532   if (succeeded(parser.parseOptionalLParen())) {
533     gangNum = parserOptionalOperandAndTypeWithPrefix(
534         parser, result, LoopOp::getGangNumKeyword());
535     if (gangNum.hasValue() && failed(*gangNum))
536       return failure();
537     parser.parseOptionalComma();
538     gangStatic = parserOptionalOperandAndTypeWithPrefix(
539         parser, result, LoopOp::getGangStaticKeyword());
540     if (gangStatic.hasValue() && failed(*gangStatic))
541       return failure();
542     parser.parseOptionalComma();
543     if (failed(parser.parseRParen()))
544       return failure();
545   }
546 
547   // worker?
548   if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerKeyword())))
549     executionMapping |= OpenACCExecMapping::WORKER;
550 
551   // optional worker operand
552   worker = parseOptionalOperandAndType(parser, result);
553   if (worker.hasValue() && failed(*worker))
554     return failure();
555 
556   // vector?
557   if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorKeyword())))
558     executionMapping |= OpenACCExecMapping::VECTOR;
559 
560   // optional vector operand
561   vector = parseOptionalOperandAndType(parser, result);
562   if (vector.hasValue() && failed(*vector))
563     return failure();
564 
565   // tile()?
566   if (failed(parseOperandList(parser, LoopOp::getTileKeyword(), tileOperands,
567                               operandTypes, result)))
568     return failure();
569 
570   // private()?
571   if (failed(parseOperandList(parser, LoopOp::getPrivateKeyword(),
572                               privateOperands, operandTypes, result)))
573     return failure();
574 
575   // reduction()?
576   if (failed(parseOperandList(parser, LoopOp::getReductionKeyword(),
577                               reductionOperands, operandTypes, result)))
578     return failure();
579 
580   if (executionMapping != acc::OpenACCExecMapping::NONE)
581     result.addAttribute(LoopOp::getExecutionMappingAttrName(),
582                         builder.getI64IntegerAttr(executionMapping));
583 
584   // Parse optional results in case there is a reduce.
585   if (parser.parseOptionalArrowTypeList(result.types))
586     return failure();
587 
588   if (failed(parseRegions<LoopOp>(parser, result)))
589     return failure();
590 
591   result.addAttribute(LoopOp::getOperandSegmentSizeAttr(),
592                       builder.getI32VectorAttr(
593                           {static_cast<int32_t>(gangNum.hasValue() ? 1 : 0),
594                            static_cast<int32_t>(gangStatic.hasValue() ? 1 : 0),
595                            static_cast<int32_t>(worker.hasValue() ? 1 : 0),
596                            static_cast<int32_t>(vector.hasValue() ? 1 : 0),
597                            static_cast<int32_t>(tileOperands.size()),
598                            static_cast<int32_t>(privateOperands.size()),
599                            static_cast<int32_t>(reductionOperands.size())}));
600 
601   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
602     return failure();
603 
604   return success();
605 }
606 
607 static void print(OpAsmPrinter &printer, LoopOp &op) {
608   unsigned execMapping = op.exec_mapping();
609   if (execMapping & OpenACCExecMapping::GANG) {
610     printer << " " << LoopOp::getGangKeyword();
611     Value gangNum = op.gangNum();
612     Value gangStatic = op.gangStatic();
613 
614     // Print optional gang operands
615     if (gangNum || gangStatic) {
616       printer << "(";
617       if (gangNum) {
618         printer << LoopOp::getGangNumKeyword() << "=" << gangNum << ": "
619                 << gangNum.getType();
620         if (gangStatic)
621           printer << ", ";
622       }
623       if (gangStatic)
624         printer << LoopOp::getGangStaticKeyword() << "=" << gangStatic << ": "
625                 << gangStatic.getType();
626       printer << ")";
627     }
628   }
629 
630   if (execMapping & OpenACCExecMapping::WORKER) {
631     printer << " " << LoopOp::getWorkerKeyword();
632 
633     // Print optional worker operand if present
634     if (Value workerNum = op.workerNum())
635       printer << "(" << workerNum << ": " << workerNum.getType() << ")";
636   }
637 
638   if (execMapping & OpenACCExecMapping::VECTOR) {
639     printer << " " << LoopOp::getVectorKeyword();
640 
641     // Print optional vector operand if present
642     if (Value vectorLength = op.vectorLength())
643       printer << "(" << vectorLength << ": " << vectorLength.getType() << ")";
644   }
645 
646   // tile()?
647   printOperandList(op.tileOperands(), LoopOp::getTileKeyword(), printer);
648 
649   // private()?
650   printOperandList(op.privateOperands(), LoopOp::getPrivateKeyword(), printer);
651 
652   // reduction()?
653   printOperandList(op.reductionOperands(), LoopOp::getReductionKeyword(),
654                    printer);
655 
656   if (op.getNumResults() > 0)
657     printer << " -> (" << op.getResultTypes() << ")";
658 
659   printer << ' ';
660   printer.printRegion(op.region(),
661                       /*printEntryBlockArgs=*/false,
662                       /*printBlockTerminators=*/true);
663 
664   printer.printOptionalAttrDictWithKeyword(
665       op->getAttrs(), {LoopOp::getExecutionMappingAttrName(),
666                        LoopOp::getOperandSegmentSizeAttr()});
667 }
668 
669 static LogicalResult verifyLoopOp(acc::LoopOp loopOp) {
670   // auto, independent and seq attribute are mutually exclusive.
671   if ((loopOp.auto_() && (loopOp.independent() || loopOp.seq())) ||
672       (loopOp.independent() && loopOp.seq())) {
673     loopOp.emitError("only one of " + acc::LoopOp::getAutoAttrName() + ", " +
674                      acc::LoopOp::getIndependentAttrName() + ", " +
675                      acc::LoopOp::getSeqAttrName() +
676                      " can be present at the same time");
677     return failure();
678   }
679 
680   // Gang, worker and vector are incompatible with seq.
681   if (loopOp.seq() && loopOp.exec_mapping() != OpenACCExecMapping::NONE) {
682     loopOp.emitError("gang, worker or vector cannot appear with the seq attr");
683     return failure();
684   }
685 
686   // Check non-empty body().
687   if (loopOp.region().empty()) {
688     loopOp.emitError("expected non-empty body.");
689     return failure();
690   }
691 
692   return success();
693 }
694 
695 //===----------------------------------------------------------------------===//
696 // DataOp
697 //===----------------------------------------------------------------------===//
698 
699 static LogicalResult verify(acc::DataOp dataOp) {
700   // 2.6.5. Data Construct restriction
701   // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
702   // attach, or default clause must appear on a data construct.
703   if (dataOp.getOperands().empty() && !dataOp.defaultAttr())
704     return dataOp.emitError("at least one operand or the default attribute "
705                             "must appear on the data operation");
706   return success();
707 }
708 
709 unsigned DataOp::getNumDataOperands() {
710   return copyOperands().size() + copyinOperands().size() +
711          copyinReadonlyOperands().size() + copyoutOperands().size() +
712          copyoutZeroOperands().size() + createOperands().size() +
713          createZeroOperands().size() + noCreateOperands().size() +
714          presentOperands().size() + deviceptrOperands().size() +
715          attachOperands().size();
716 }
717 
718 Value DataOp::getDataOperand(unsigned i) {
719   unsigned numOptional = ifCond() ? 1 : 0;
720   return getOperand(numOptional + i);
721 }
722 
723 //===----------------------------------------------------------------------===//
724 // ExitDataOp
725 //===----------------------------------------------------------------------===//
726 
727 static LogicalResult verify(acc::ExitDataOp op) {
728   // 2.6.6. Data Exit Directive restriction
729   // At least one copyout, delete, or detach clause must appear on an exit data
730   // directive.
731   if (op.copyoutOperands().empty() && op.deleteOperands().empty() &&
732       op.detachOperands().empty())
733     return op.emitError(
734         "at least one operand in copyout, delete or detach must appear on the "
735         "exit data operation");
736 
737   // The async attribute represent the async clause without value. Therefore the
738   // attribute and operand cannot appear at the same time.
739   if (op.asyncOperand() && op.async())
740     return op.emitError("async attribute cannot appear with asyncOperand");
741 
742   // The wait attribute represent the wait clause without values. Therefore the
743   // attribute and operands cannot appear at the same time.
744   if (!op.waitOperands().empty() && op.wait())
745     return op.emitError("wait attribute cannot appear with waitOperands");
746 
747   if (op.waitDevnum() && op.waitOperands().empty())
748     return op.emitError("wait_devnum cannot appear without waitOperands");
749 
750   return success();
751 }
752 
753 unsigned ExitDataOp::getNumDataOperands() {
754   return copyoutOperands().size() + deleteOperands().size() +
755          detachOperands().size();
756 }
757 
758 Value ExitDataOp::getDataOperand(unsigned i) {
759   unsigned numOptional = ifCond() ? 1 : 0;
760   numOptional += asyncOperand() ? 1 : 0;
761   numOptional += waitDevnum() ? 1 : 0;
762   return getOperand(waitOperands().size() + numOptional + i);
763 }
764 
765 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
766                                              MLIRContext *context) {
767   results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
768 }
769 
770 //===----------------------------------------------------------------------===//
771 // EnterDataOp
772 //===----------------------------------------------------------------------===//
773 
774 static LogicalResult verify(acc::EnterDataOp op) {
775   // 2.6.6. Data Enter Directive restriction
776   // At least one copyin, create, or attach clause must appear on an enter data
777   // directive.
778   if (op.copyinOperands().empty() && op.createOperands().empty() &&
779       op.createZeroOperands().empty() && op.attachOperands().empty())
780     return op.emitError(
781         "at least one operand in copyin, create, "
782         "create_zero or attach must appear on the enter data operation");
783 
784   // The async attribute represent the async clause without value. Therefore the
785   // attribute and operand cannot appear at the same time.
786   if (op.asyncOperand() && op.async())
787     return op.emitError("async attribute cannot appear with asyncOperand");
788 
789   // The wait attribute represent the wait clause without values. Therefore the
790   // attribute and operands cannot appear at the same time.
791   if (!op.waitOperands().empty() && op.wait())
792     return op.emitError("wait attribute cannot appear with waitOperands");
793 
794   if (op.waitDevnum() && op.waitOperands().empty())
795     return op.emitError("wait_devnum cannot appear without waitOperands");
796 
797   return success();
798 }
799 
800 unsigned EnterDataOp::getNumDataOperands() {
801   return copyinOperands().size() + createOperands().size() +
802          createZeroOperands().size() + attachOperands().size();
803 }
804 
805 Value EnterDataOp::getDataOperand(unsigned i) {
806   unsigned numOptional = ifCond() ? 1 : 0;
807   numOptional += asyncOperand() ? 1 : 0;
808   numOptional += waitDevnum() ? 1 : 0;
809   return getOperand(waitOperands().size() + numOptional + i);
810 }
811 
812 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
813                                               MLIRContext *context) {
814   results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
815 }
816 
817 //===----------------------------------------------------------------------===//
818 // InitOp
819 //===----------------------------------------------------------------------===//
820 
821 static LogicalResult verify(acc::InitOp initOp) {
822   Operation *currOp = initOp;
823   while ((currOp = currOp->getParentOp())) {
824     if (isComputeOperation(currOp))
825       return initOp.emitOpError("cannot be nested in a compute operation");
826   }
827   return success();
828 }
829 
830 //===----------------------------------------------------------------------===//
831 // ShutdownOp
832 //===----------------------------------------------------------------------===//
833 
834 static LogicalResult verify(acc::ShutdownOp op) {
835   Operation *currOp = op;
836   while ((currOp = currOp->getParentOp())) {
837     if (isComputeOperation(currOp))
838       return op.emitOpError("cannot be nested in a compute operation");
839   }
840   return success();
841 }
842 
843 //===----------------------------------------------------------------------===//
844 // UpdateOp
845 //===----------------------------------------------------------------------===//
846 
847 static LogicalResult verify(acc::UpdateOp updateOp) {
848   // At least one of host or device should have a value.
849   if (updateOp.hostOperands().empty() && updateOp.deviceOperands().empty())
850     return updateOp.emitError("at least one value must be present in"
851                               " hostOperands or deviceOperands");
852 
853   // The async attribute represent the async clause without value. Therefore the
854   // attribute and operand cannot appear at the same time.
855   if (updateOp.asyncOperand() && updateOp.async())
856     return updateOp.emitError("async attribute cannot appear with "
857                               " asyncOperand");
858 
859   // The wait attribute represent the wait clause without values. Therefore the
860   // attribute and operands cannot appear at the same time.
861   if (!updateOp.waitOperands().empty() && updateOp.wait())
862     return updateOp.emitError("wait attribute cannot appear with waitOperands");
863 
864   if (updateOp.waitDevnum() && updateOp.waitOperands().empty())
865     return updateOp.emitError("wait_devnum cannot appear without waitOperands");
866 
867   return success();
868 }
869 
870 unsigned UpdateOp::getNumDataOperands() {
871   return hostOperands().size() + deviceOperands().size();
872 }
873 
874 Value UpdateOp::getDataOperand(unsigned i) {
875   unsigned numOptional = asyncOperand() ? 1 : 0;
876   numOptional += waitDevnum() ? 1 : 0;
877   numOptional += ifCond() ? 1 : 0;
878   return getOperand(waitOperands().size() + deviceTypeOperands().size() +
879                     numOptional + i);
880 }
881 
882 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
883                                            MLIRContext *context) {
884   results.add<RemoveConstantIfCondition<UpdateOp>>(context);
885 }
886 
887 //===----------------------------------------------------------------------===//
888 // WaitOp
889 //===----------------------------------------------------------------------===//
890 
891 static LogicalResult verify(acc::WaitOp waitOp) {
892   // The async attribute represent the async clause without value. Therefore the
893   // attribute and operand cannot appear at the same time.
894   if (waitOp.asyncOperand() && waitOp.async())
895     return waitOp.emitError("async attribute cannot appear with asyncOperand");
896 
897   if (waitOp.waitDevnum() && waitOp.waitOperands().empty())
898     return waitOp.emitError("wait_devnum cannot appear without waitOperands");
899 
900   return success();
901 }
902 
903 #define GET_OP_CLASSES
904 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
905 
906 #define GET_ATTRDEF_CLASSES
907 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
908