1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // =============================================================================
8 
9 #include "mlir/Dialect/OpenACC/OpenACC.h"
10 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "llvm/ADT/TypeSwitch.h"
18 
19 using namespace mlir;
20 using namespace acc;
21 
22 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
23 
24 //===----------------------------------------------------------------------===//
25 // OpenACC operations
26 //===----------------------------------------------------------------------===//
27 
initialize()28 void OpenACCDialect::initialize() {
29   addOperations<
30 #define GET_OP_LIST
31 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
32       >();
33   addAttributes<
34 #define GET_ATTRDEF_LIST
35 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
36       >();
37 }
38 
39 template <typename StructureOp>
parseRegions(OpAsmParser & parser,OperationState & state,unsigned nRegions=1)40 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
41                                 unsigned nRegions = 1) {
42 
43   SmallVector<Region *, 2> regions;
44   for (unsigned i = 0; i < nRegions; ++i)
45     regions.push_back(state.addRegion());
46 
47   for (Region *region : regions) {
48     if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
49       return failure();
50   }
51 
52   return success();
53 }
54 
55 static ParseResult
parseOperandList(OpAsmParser & parser,StringRef keyword,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & args,SmallVectorImpl<Type> & argTypes,OperationState & result)56 parseOperandList(OpAsmParser &parser, StringRef keyword,
57                  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &args,
58                  SmallVectorImpl<Type> &argTypes, OperationState &result) {
59   if (failed(parser.parseOptionalKeyword(keyword)))
60     return success();
61 
62   if (failed(parser.parseLParen()))
63     return failure();
64 
65   // Exit early if the list is empty.
66   if (succeeded(parser.parseOptionalRParen()))
67     return success();
68 
69   if (failed(parser.parseCommaSeparatedList([&]() {
70         OpAsmParser::UnresolvedOperand arg;
71         Type type;
72 
73         if (parser.parseOperand(arg, /*allowResultNumber=*/false) ||
74             parser.parseColonType(type))
75           return failure();
76 
77         args.push_back(arg);
78         argTypes.push_back(type);
79         return success();
80       })) ||
81       failed(parser.parseRParen()))
82     return failure();
83 
84   return parser.resolveOperands(args, argTypes, parser.getCurrentLocation(),
85                                 result.operands);
86 }
87 
printOperandList(Operation::operand_range operands,StringRef listName,OpAsmPrinter & printer)88 static void printOperandList(Operation::operand_range operands,
89                              StringRef listName, OpAsmPrinter &printer) {
90 
91   if (!operands.empty()) {
92     printer << " " << listName << "(";
93     llvm::interleaveComma(operands, printer, [&](Value op) {
94       printer << op << ": " << op.getType();
95     });
96     printer << ")";
97   }
98 }
99 
parseOptionalOperand(OpAsmParser & parser,StringRef keyword,OpAsmParser::UnresolvedOperand & operand,Type type,bool & hasOptional,OperationState & result)100 static ParseResult parseOptionalOperand(OpAsmParser &parser, StringRef keyword,
101                                         OpAsmParser::UnresolvedOperand &operand,
102                                         Type type, bool &hasOptional,
103                                         OperationState &result) {
104   hasOptional = false;
105   if (succeeded(parser.parseOptionalKeyword(keyword))) {
106     hasOptional = true;
107     if (parser.parseLParen() || parser.parseOperand(operand) ||
108         parser.resolveOperand(operand, type, result.operands) ||
109         parser.parseRParen())
110       return failure();
111   }
112   return success();
113 }
114 
parseOperandAndType(OpAsmParser & parser,OperationState & result)115 static ParseResult parseOperandAndType(OpAsmParser &parser,
116                                        OperationState &result) {
117   OpAsmParser::UnresolvedOperand operand;
118   Type type;
119   if (parser.parseOperand(operand) || parser.parseColonType(type) ||
120       parser.resolveOperand(operand, type, result.operands))
121     return failure();
122   return success();
123 }
124 
125 /// Parse optional operand and its type wrapped in parenthesis prefixed with
126 /// a keyword.
127 /// Example:
128 ///   keyword `(` %vectorLength: i64 `)`
parseOptionalOperandAndType(OpAsmParser & parser,StringRef keyword,OperationState & result)129 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser,
130                                                        StringRef keyword,
131                                                        OperationState &result) {
132   OpAsmParser::UnresolvedOperand operand;
133   if (succeeded(parser.parseOptionalKeyword(keyword))) {
134     return failure(parser.parseLParen() ||
135                    parseOperandAndType(parser, result) || parser.parseRParen());
136   }
137   return llvm::None;
138 }
139 
140 /// Parse optional operand and its type wrapped in parenthesis.
141 /// Example:
142 ///   `(` %vectorLength: i64 `)`
parseOptionalOperandAndType(OpAsmParser & parser,OperationState & result)143 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser,
144                                                        OperationState &result) {
145   if (succeeded(parser.parseOptionalLParen())) {
146     return failure(parseOperandAndType(parser, result) || parser.parseRParen());
147   }
148   return llvm::None;
149 }
150 
151 /// Parse optional operand with its type prefixed with prefixKeyword `=`.
152 /// Example:
153 ///   num=%gangNum: i32
parserOptionalOperandAndTypeWithPrefix(OpAsmParser & parser,OperationState & result,StringRef prefixKeyword)154 static OptionalParseResult parserOptionalOperandAndTypeWithPrefix(
155     OpAsmParser &parser, OperationState &result, StringRef prefixKeyword) {
156   if (succeeded(parser.parseOptionalKeyword(prefixKeyword))) {
157     if (parser.parseEqual() || parseOperandAndType(parser, result))
158       return failure();
159     return success();
160   }
161   return llvm::None;
162 }
163 
isComputeOperation(Operation * op)164 static bool isComputeOperation(Operation *op) {
165   return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op);
166 }
167 
168 namespace {
169 /// Pattern to remove operation without region that have constant false `ifCond`
170 /// and remove the condition from the operation if the `ifCond` is a true
171 /// constant.
172 template <typename OpTy>
173 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
174   using OpRewritePattern<OpTy>::OpRewritePattern;
175 
matchAndRewrite__anon648eb3690311::RemoveConstantIfCondition176   LogicalResult matchAndRewrite(OpTy op,
177                                 PatternRewriter &rewriter) const override {
178     // Early return if there is no condition.
179     Value ifCond = op.ifCond();
180     if (!ifCond)
181       return success();
182 
183     IntegerAttr constAttr;
184     if (matchPattern(ifCond, m_Constant(&constAttr))) {
185       if (constAttr.getInt())
186         rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); });
187       else
188         rewriter.eraseOp(op);
189     }
190 
191     return success();
192   }
193 };
194 } // namespace
195 
196 //===----------------------------------------------------------------------===//
197 // ParallelOp
198 //===----------------------------------------------------------------------===//
199 
200 /// Parse acc.parallel operation
201 /// operation := `acc.parallel` `async` `(` index `)`?
202 ///                             `wait` `(` index-list `)`?
203 ///                             `num_gangs` `(` value `)`?
204 ///                             `num_workers` `(` value `)`?
205 ///                             `vector_length` `(` value `)`?
206 ///                             `if` `(` value `)`?
207 ///                             `self` `(` value `)`?
208 ///                             `reduction` `(` value-list `)`?
209 ///                             `copy` `(` value-list `)`?
210 ///                             `copyin` `(` value-list `)`?
211 ///                             `copyin_readonly` `(` value-list `)`?
212 ///                             `copyout` `(` value-list `)`?
213 ///                             `copyout_zero` `(` value-list `)`?
214 ///                             `create` `(` value-list `)`?
215 ///                             `create_zero` `(` value-list `)`?
216 ///                             `no_create` `(` value-list `)`?
217 ///                             `present` `(` value-list `)`?
218 ///                             `deviceptr` `(` value-list `)`?
219 ///                             `attach` `(` value-list `)`?
220 ///                             `private` `(` value-list `)`?
221 ///                             `firstprivate` `(` value-list `)`?
222 ///                             region attr-dict?
parse(OpAsmParser & parser,OperationState & result)223 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
224   Builder &builder = parser.getBuilder();
225   SmallVector<OpAsmParser::UnresolvedOperand, 8> privateOperands,
226       firstprivateOperands, copyOperands, copyinOperands,
227       copyinReadonlyOperands, copyoutOperands, copyoutZeroOperands,
228       createOperands, createZeroOperands, noCreateOperands, presentOperands,
229       devicePtrOperands, attachOperands, waitOperands, reductionOperands;
230   SmallVector<Type, 8> waitOperandTypes, reductionOperandTypes,
231       copyOperandTypes, copyinOperandTypes, copyinReadonlyOperandTypes,
232       copyoutOperandTypes, copyoutZeroOperandTypes, createOperandTypes,
233       createZeroOperandTypes, noCreateOperandTypes, presentOperandTypes,
234       deviceptrOperandTypes, attachOperandTypes, privateOperandTypes,
235       firstprivateOperandTypes;
236 
237   SmallVector<Type, 8> operandTypes;
238   OpAsmParser::UnresolvedOperand ifCond, selfCond;
239   bool hasIfCond = false, hasSelfCond = false;
240   OptionalParseResult async, numGangs, numWorkers, vectorLength;
241   Type i1Type = builder.getI1Type();
242 
243   // async()?
244   async = parseOptionalOperandAndType(parser, ParallelOp::getAsyncKeyword(),
245                                       result);
246   if (async.hasValue() && failed(*async))
247     return failure();
248 
249   // wait()?
250   if (failed(parseOperandList(parser, ParallelOp::getWaitKeyword(),
251                               waitOperands, waitOperandTypes, result)))
252     return failure();
253 
254   // num_gangs(value)?
255   numGangs = parseOptionalOperandAndType(
256       parser, ParallelOp::getNumGangsKeyword(), result);
257   if (numGangs.hasValue() && failed(*numGangs))
258     return failure();
259 
260   // num_workers(value)?
261   numWorkers = parseOptionalOperandAndType(
262       parser, ParallelOp::getNumWorkersKeyword(), result);
263   if (numWorkers.hasValue() && failed(*numWorkers))
264     return failure();
265 
266   // vector_length(value)?
267   vectorLength = parseOptionalOperandAndType(
268       parser, ParallelOp::getVectorLengthKeyword(), result);
269   if (vectorLength.hasValue() && failed(*vectorLength))
270     return failure();
271 
272   // if()?
273   if (failed(parseOptionalOperand(parser, ParallelOp::getIfKeyword(), ifCond,
274                                   i1Type, hasIfCond, result)))
275     return failure();
276 
277   // self()?
278   if (failed(parseOptionalOperand(parser, ParallelOp::getSelfKeyword(),
279                                   selfCond, i1Type, hasSelfCond, result)))
280     return failure();
281 
282   // reduction()?
283   if (failed(parseOperandList(parser, ParallelOp::getReductionKeyword(),
284                               reductionOperands, reductionOperandTypes,
285                               result)))
286     return failure();
287 
288   // copy()?
289   if (failed(parseOperandList(parser, ParallelOp::getCopyKeyword(),
290                               copyOperands, copyOperandTypes, result)))
291     return failure();
292 
293   // copyin()?
294   if (failed(parseOperandList(parser, ParallelOp::getCopyinKeyword(),
295                               copyinOperands, copyinOperandTypes, result)))
296     return failure();
297 
298   // copyin_readonly()?
299   if (failed(parseOperandList(parser, ParallelOp::getCopyinReadonlyKeyword(),
300                               copyinReadonlyOperands,
301                               copyinReadonlyOperandTypes, result)))
302     return failure();
303 
304   // copyout()?
305   if (failed(parseOperandList(parser, ParallelOp::getCopyoutKeyword(),
306                               copyoutOperands, copyoutOperandTypes, result)))
307     return failure();
308 
309   // copyout_zero()?
310   if (failed(parseOperandList(parser, ParallelOp::getCopyoutZeroKeyword(),
311                               copyoutZeroOperands, copyoutZeroOperandTypes,
312                               result)))
313     return failure();
314 
315   // create()?
316   if (failed(parseOperandList(parser, ParallelOp::getCreateKeyword(),
317                               createOperands, createOperandTypes, result)))
318     return failure();
319 
320   // create_zero()?
321   if (failed(parseOperandList(parser, ParallelOp::getCreateZeroKeyword(),
322                               createZeroOperands, createZeroOperandTypes,
323                               result)))
324     return failure();
325 
326   // no_create()?
327   if (failed(parseOperandList(parser, ParallelOp::getNoCreateKeyword(),
328                               noCreateOperands, noCreateOperandTypes, result)))
329     return failure();
330 
331   // present()?
332   if (failed(parseOperandList(parser, ParallelOp::getPresentKeyword(),
333                               presentOperands, presentOperandTypes, result)))
334     return failure();
335 
336   // deviceptr()?
337   if (failed(parseOperandList(parser, ParallelOp::getDevicePtrKeyword(),
338                               devicePtrOperands, deviceptrOperandTypes,
339                               result)))
340     return failure();
341 
342   // attach()?
343   if (failed(parseOperandList(parser, ParallelOp::getAttachKeyword(),
344                               attachOperands, attachOperandTypes, result)))
345     return failure();
346 
347   // private()?
348   if (failed(parseOperandList(parser, ParallelOp::getPrivateKeyword(),
349                               privateOperands, privateOperandTypes, result)))
350     return failure();
351 
352   // firstprivate()?
353   if (failed(parseOperandList(parser, ParallelOp::getFirstPrivateKeyword(),
354                               firstprivateOperands, firstprivateOperandTypes,
355                               result)))
356     return failure();
357 
358   // Parallel op region
359   if (failed(parseRegions<ParallelOp>(parser, result)))
360     return failure();
361 
362   result.addAttribute(
363       ParallelOp::getOperandSegmentSizeAttr(),
364       builder.getI32VectorAttr(
365           {static_cast<int32_t>(async.hasValue() ? 1 : 0),
366            static_cast<int32_t>(waitOperands.size()),
367            static_cast<int32_t>(numGangs.hasValue() ? 1 : 0),
368            static_cast<int32_t>(numWorkers.hasValue() ? 1 : 0),
369            static_cast<int32_t>(vectorLength.hasValue() ? 1 : 0),
370            static_cast<int32_t>(hasIfCond ? 1 : 0),
371            static_cast<int32_t>(hasSelfCond ? 1 : 0),
372            static_cast<int32_t>(reductionOperands.size()),
373            static_cast<int32_t>(copyOperands.size()),
374            static_cast<int32_t>(copyinOperands.size()),
375            static_cast<int32_t>(copyinReadonlyOperands.size()),
376            static_cast<int32_t>(copyoutOperands.size()),
377            static_cast<int32_t>(copyoutZeroOperands.size()),
378            static_cast<int32_t>(createOperands.size()),
379            static_cast<int32_t>(createZeroOperands.size()),
380            static_cast<int32_t>(noCreateOperands.size()),
381            static_cast<int32_t>(presentOperands.size()),
382            static_cast<int32_t>(devicePtrOperands.size()),
383            static_cast<int32_t>(attachOperands.size()),
384            static_cast<int32_t>(privateOperands.size()),
385            static_cast<int32_t>(firstprivateOperands.size())}));
386 
387   // Additional attributes
388   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
389     return failure();
390 
391   return success();
392 }
393 
print(OpAsmPrinter & printer)394 void ParallelOp::print(OpAsmPrinter &printer) {
395   // async()?
396   if (Value async = this->async())
397     printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ": "
398             << async.getType() << ")";
399 
400   // wait()?
401   printOperandList(waitOperands(), ParallelOp::getWaitKeyword(), printer);
402 
403   // num_gangs()?
404   if (Value numGangs = this->numGangs())
405     printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs
406             << ": " << numGangs.getType() << ")";
407 
408   // num_workers()?
409   if (Value numWorkers = this->numWorkers())
410     printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers
411             << ": " << numWorkers.getType() << ")";
412 
413   // vector_length()?
414   if (Value vectorLength = this->vectorLength())
415     printer << " " << ParallelOp::getVectorLengthKeyword() << "("
416             << vectorLength << ": " << vectorLength.getType() << ")";
417 
418   // if()?
419   if (Value ifCond = this->ifCond())
420     printer << " " << ParallelOp::getIfKeyword() << "(" << ifCond << ")";
421 
422   // self()?
423   if (Value selfCond = this->selfCond())
424     printer << " " << ParallelOp::getSelfKeyword() << "(" << selfCond << ")";
425 
426   // reduction()?
427   printOperandList(reductionOperands(), ParallelOp::getReductionKeyword(),
428                    printer);
429 
430   // copy()?
431   printOperandList(copyOperands(), ParallelOp::getCopyKeyword(), printer);
432 
433   // copyin()?
434   printOperandList(copyinOperands(), ParallelOp::getCopyinKeyword(), printer);
435 
436   // copyin_readonly()?
437   printOperandList(copyinReadonlyOperands(),
438                    ParallelOp::getCopyinReadonlyKeyword(), printer);
439 
440   // copyout()?
441   printOperandList(copyoutOperands(), ParallelOp::getCopyoutKeyword(), printer);
442 
443   // copyout_zero()?
444   printOperandList(copyoutZeroOperands(), ParallelOp::getCopyoutZeroKeyword(),
445                    printer);
446 
447   // create()?
448   printOperandList(createOperands(), ParallelOp::getCreateKeyword(), printer);
449 
450   // create_zero()?
451   printOperandList(createZeroOperands(), ParallelOp::getCreateZeroKeyword(),
452                    printer);
453 
454   // no_create()?
455   printOperandList(noCreateOperands(), ParallelOp::getNoCreateKeyword(),
456                    printer);
457 
458   // present()?
459   printOperandList(presentOperands(), ParallelOp::getPresentKeyword(), printer);
460 
461   // deviceptr()?
462   printOperandList(devicePtrOperands(), ParallelOp::getDevicePtrKeyword(),
463                    printer);
464 
465   // attach()?
466   printOperandList(attachOperands(), ParallelOp::getAttachKeyword(), printer);
467 
468   // private()?
469   printOperandList(gangPrivateOperands(), ParallelOp::getPrivateKeyword(),
470                    printer);
471 
472   // firstprivate()?
473   printOperandList(gangFirstPrivateOperands(),
474                    ParallelOp::getFirstPrivateKeyword(), printer);
475 
476   printer << ' ';
477   printer.printRegion(region(),
478                       /*printEntryBlockArgs=*/false,
479                       /*printBlockTerminators=*/true);
480   printer.printOptionalAttrDictWithKeyword(
481       (*this)->getAttrs(), ParallelOp::getOperandSegmentSizeAttr());
482 }
483 
getNumDataOperands()484 unsigned ParallelOp::getNumDataOperands() {
485   return reductionOperands().size() + copyOperands().size() +
486          copyinOperands().size() + copyinReadonlyOperands().size() +
487          copyoutOperands().size() + copyoutZeroOperands().size() +
488          createOperands().size() + createZeroOperands().size() +
489          noCreateOperands().size() + presentOperands().size() +
490          devicePtrOperands().size() + attachOperands().size() +
491          gangPrivateOperands().size() + gangFirstPrivateOperands().size();
492 }
493 
getDataOperand(unsigned i)494 Value ParallelOp::getDataOperand(unsigned i) {
495   unsigned numOptional = async() ? 1 : 0;
496   numOptional += numGangs() ? 1 : 0;
497   numOptional += numWorkers() ? 1 : 0;
498   numOptional += vectorLength() ? 1 : 0;
499   numOptional += ifCond() ? 1 : 0;
500   numOptional += selfCond() ? 1 : 0;
501   return getOperand(waitOperands().size() + numOptional + i);
502 }
503 
504 //===----------------------------------------------------------------------===//
505 // LoopOp
506 //===----------------------------------------------------------------------===//
507 
508 /// Parse acc.loop operation
509 /// operation := `acc.loop`
510 ///              (`gang` ( `(` (`num=` value)? (`,` `static=` value `)`)? )? )?
511 ///              (`vector` ( `(` value `)` )? )? (`worker` (`(` value `)`)? )?
512 ///              (`vector_length` `(` value `)`)?
513 ///              (`tile` `(` value-list `)`)?
514 ///              (`private` `(` value-list `)`)?
515 ///              (`reduction` `(` value-list `)`)?
516 ///              region attr-dict?
parse(OpAsmParser & parser,OperationState & result)517 ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
518   Builder &builder = parser.getBuilder();
519   unsigned executionMapping = OpenACCExecMapping::NONE;
520   SmallVector<Type, 8> operandTypes;
521   SmallVector<OpAsmParser::UnresolvedOperand, 8> privateOperands,
522       reductionOperands;
523   SmallVector<OpAsmParser::UnresolvedOperand, 8> tileOperands;
524   OptionalParseResult gangNum, gangStatic, worker, vector;
525 
526   // gang?
527   if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangKeyword())))
528     executionMapping |= OpenACCExecMapping::GANG;
529 
530   // optional gang operand
531   if (succeeded(parser.parseOptionalLParen())) {
532     gangNum = parserOptionalOperandAndTypeWithPrefix(
533         parser, result, LoopOp::getGangNumKeyword());
534     if (gangNum.hasValue() && failed(*gangNum))
535       return failure();
536     // FIXME: Comma should require subsequent operands.
537     (void)parser.parseOptionalComma();
538     gangStatic = parserOptionalOperandAndTypeWithPrefix(
539         parser, result, LoopOp::getGangStaticKeyword());
540     if (gangStatic.hasValue() && failed(*gangStatic))
541       return failure();
542     // FIXME: Why allow optional last commas?
543     (void)parser.parseOptionalComma();
544     if (failed(parser.parseRParen()))
545       return failure();
546   }
547 
548   // worker?
549   if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerKeyword())))
550     executionMapping |= OpenACCExecMapping::WORKER;
551 
552   // optional worker operand
553   worker = parseOptionalOperandAndType(parser, result);
554   if (worker.hasValue() && failed(*worker))
555     return failure();
556 
557   // vector?
558   if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorKeyword())))
559     executionMapping |= OpenACCExecMapping::VECTOR;
560 
561   // optional vector operand
562   vector = parseOptionalOperandAndType(parser, result);
563   if (vector.hasValue() && failed(*vector))
564     return failure();
565 
566   // tile()?
567   if (failed(parseOperandList(parser, LoopOp::getTileKeyword(), tileOperands,
568                               operandTypes, result)))
569     return failure();
570 
571   // private()?
572   if (failed(parseOperandList(parser, LoopOp::getPrivateKeyword(),
573                               privateOperands, operandTypes, result)))
574     return failure();
575 
576   // reduction()?
577   if (failed(parseOperandList(parser, LoopOp::getReductionKeyword(),
578                               reductionOperands, operandTypes, result)))
579     return failure();
580 
581   if (executionMapping != acc::OpenACCExecMapping::NONE)
582     result.addAttribute(LoopOp::getExecutionMappingAttrName(),
583                         builder.getI64IntegerAttr(executionMapping));
584 
585   // Parse optional results in case there is a reduce.
586   if (parser.parseOptionalArrowTypeList(result.types))
587     return failure();
588 
589   if (failed(parseRegions<LoopOp>(parser, result)))
590     return failure();
591 
592   result.addAttribute(LoopOp::getOperandSegmentSizeAttr(),
593                       builder.getI32VectorAttr(
594                           {static_cast<int32_t>(gangNum.hasValue() ? 1 : 0),
595                            static_cast<int32_t>(gangStatic.hasValue() ? 1 : 0),
596                            static_cast<int32_t>(worker.hasValue() ? 1 : 0),
597                            static_cast<int32_t>(vector.hasValue() ? 1 : 0),
598                            static_cast<int32_t>(tileOperands.size()),
599                            static_cast<int32_t>(privateOperands.size()),
600                            static_cast<int32_t>(reductionOperands.size())}));
601 
602   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
603     return failure();
604 
605   return success();
606 }
607 
print(OpAsmPrinter & printer)608 void LoopOp::print(OpAsmPrinter &printer) {
609   unsigned execMapping = exec_mapping();
610   if (execMapping & OpenACCExecMapping::GANG) {
611     printer << " " << LoopOp::getGangKeyword();
612     Value gangNum = this->gangNum();
613     Value gangStatic = this->gangStatic();
614 
615     // Print optional gang operands
616     if (gangNum || gangStatic) {
617       printer << "(";
618       if (gangNum) {
619         printer << LoopOp::getGangNumKeyword() << "=" << gangNum << ": "
620                 << gangNum.getType();
621         if (gangStatic)
622           printer << ", ";
623       }
624       if (gangStatic)
625         printer << LoopOp::getGangStaticKeyword() << "=" << gangStatic << ": "
626                 << gangStatic.getType();
627       printer << ")";
628     }
629   }
630 
631   if (execMapping & OpenACCExecMapping::WORKER) {
632     printer << " " << LoopOp::getWorkerKeyword();
633 
634     // Print optional worker operand if present
635     if (Value workerNum = this->workerNum())
636       printer << "(" << workerNum << ": " << workerNum.getType() << ")";
637   }
638 
639   if (execMapping & OpenACCExecMapping::VECTOR) {
640     printer << " " << LoopOp::getVectorKeyword();
641 
642     // Print optional vector operand if present
643     if (Value vectorLength = this->vectorLength())
644       printer << "(" << vectorLength << ": " << vectorLength.getType() << ")";
645   }
646 
647   // tile()?
648   printOperandList(tileOperands(), LoopOp::getTileKeyword(), printer);
649 
650   // private()?
651   printOperandList(privateOperands(), LoopOp::getPrivateKeyword(), printer);
652 
653   // reduction()?
654   printOperandList(reductionOperands(), LoopOp::getReductionKeyword(), printer);
655 
656   if (getNumResults() > 0)
657     printer << " -> (" << getResultTypes() << ")";
658 
659   printer << ' ';
660   printer.printRegion(region(),
661                       /*printEntryBlockArgs=*/false,
662                       /*printBlockTerminators=*/true);
663 
664   printer.printOptionalAttrDictWithKeyword(
665       (*this)->getAttrs(), {LoopOp::getExecutionMappingAttrName(),
666                             LoopOp::getOperandSegmentSizeAttr()});
667 }
668 
verify()669 LogicalResult acc::LoopOp::verify() {
670   // auto, independent and seq attribute are mutually exclusive.
671   if ((auto_() && (independent() || seq())) || (independent() && seq())) {
672     return emitError("only one of " + acc::LoopOp::getAutoAttrName() + ", " +
673                      acc::LoopOp::getIndependentAttrName() + ", " +
674                      acc::LoopOp::getSeqAttrName() +
675                      " can be present at the same time");
676   }
677 
678   // Gang, worker and vector are incompatible with seq.
679   if (seq() && exec_mapping() != OpenACCExecMapping::NONE)
680     return emitError("gang, worker or vector cannot appear with the seq attr");
681 
682   // Check non-empty body().
683   if (region().empty())
684     return emitError("expected non-empty body.");
685 
686   return success();
687 }
688 
689 //===----------------------------------------------------------------------===//
690 // DataOp
691 //===----------------------------------------------------------------------===//
692 
verify()693 LogicalResult acc::DataOp::verify() {
694   // 2.6.5. Data Construct restriction
695   // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
696   // attach, or default clause must appear on a data construct.
697   if (getOperands().empty() && !defaultAttr())
698     return emitError("at least one operand or the default attribute "
699                      "must appear on the data operation");
700   return success();
701 }
702 
getNumDataOperands()703 unsigned DataOp::getNumDataOperands() {
704   return copyOperands().size() + copyinOperands().size() +
705          copyinReadonlyOperands().size() + copyoutOperands().size() +
706          copyoutZeroOperands().size() + createOperands().size() +
707          createZeroOperands().size() + noCreateOperands().size() +
708          presentOperands().size() + deviceptrOperands().size() +
709          attachOperands().size();
710 }
711 
getDataOperand(unsigned i)712 Value DataOp::getDataOperand(unsigned i) {
713   unsigned numOptional = ifCond() ? 1 : 0;
714   return getOperand(numOptional + i);
715 }
716 
717 //===----------------------------------------------------------------------===//
718 // ExitDataOp
719 //===----------------------------------------------------------------------===//
720 
verify()721 LogicalResult acc::ExitDataOp::verify() {
722   // 2.6.6. Data Exit Directive restriction
723   // At least one copyout, delete, or detach clause must appear on an exit data
724   // directive.
725   if (copyoutOperands().empty() && deleteOperands().empty() &&
726       detachOperands().empty())
727     return emitError(
728         "at least one operand in copyout, delete or detach must appear on the "
729         "exit data operation");
730 
731   // The async attribute represent the async clause without value. Therefore the
732   // attribute and operand cannot appear at the same time.
733   if (asyncOperand() && async())
734     return emitError("async attribute cannot appear with asyncOperand");
735 
736   // The wait attribute represent the wait clause without values. Therefore the
737   // attribute and operands cannot appear at the same time.
738   if (!waitOperands().empty() && wait())
739     return emitError("wait attribute cannot appear with waitOperands");
740 
741   if (waitDevnum() && waitOperands().empty())
742     return emitError("wait_devnum cannot appear without waitOperands");
743 
744   return success();
745 }
746 
getNumDataOperands()747 unsigned ExitDataOp::getNumDataOperands() {
748   return copyoutOperands().size() + deleteOperands().size() +
749          detachOperands().size();
750 }
751 
getDataOperand(unsigned i)752 Value ExitDataOp::getDataOperand(unsigned i) {
753   unsigned numOptional = ifCond() ? 1 : 0;
754   numOptional += asyncOperand() ? 1 : 0;
755   numOptional += waitDevnum() ? 1 : 0;
756   return getOperand(waitOperands().size() + numOptional + i);
757 }
758 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)759 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
760                                              MLIRContext *context) {
761   results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
762 }
763 
764 //===----------------------------------------------------------------------===//
765 // EnterDataOp
766 //===----------------------------------------------------------------------===//
767 
verify()768 LogicalResult acc::EnterDataOp::verify() {
769   // 2.6.6. Data Enter Directive restriction
770   // At least one copyin, create, or attach clause must appear on an enter data
771   // directive.
772   if (copyinOperands().empty() && createOperands().empty() &&
773       createZeroOperands().empty() && attachOperands().empty())
774     return emitError(
775         "at least one operand in copyin, create, "
776         "create_zero or attach must appear on the enter data operation");
777 
778   // The async attribute represent the async clause without value. Therefore the
779   // attribute and operand cannot appear at the same time.
780   if (asyncOperand() && async())
781     return emitError("async attribute cannot appear with asyncOperand");
782 
783   // The wait attribute represent the wait clause without values. Therefore the
784   // attribute and operands cannot appear at the same time.
785   if (!waitOperands().empty() && wait())
786     return emitError("wait attribute cannot appear with waitOperands");
787 
788   if (waitDevnum() && waitOperands().empty())
789     return emitError("wait_devnum cannot appear without waitOperands");
790 
791   return success();
792 }
793 
getNumDataOperands()794 unsigned EnterDataOp::getNumDataOperands() {
795   return copyinOperands().size() + createOperands().size() +
796          createZeroOperands().size() + attachOperands().size();
797 }
798 
getDataOperand(unsigned i)799 Value EnterDataOp::getDataOperand(unsigned i) {
800   unsigned numOptional = ifCond() ? 1 : 0;
801   numOptional += asyncOperand() ? 1 : 0;
802   numOptional += waitDevnum() ? 1 : 0;
803   return getOperand(waitOperands().size() + numOptional + i);
804 }
805 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)806 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
807                                               MLIRContext *context) {
808   results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
809 }
810 
811 //===----------------------------------------------------------------------===//
812 // InitOp
813 //===----------------------------------------------------------------------===//
814 
verify()815 LogicalResult acc::InitOp::verify() {
816   Operation *currOp = *this;
817   while ((currOp = currOp->getParentOp()))
818     if (isComputeOperation(currOp))
819       return emitOpError("cannot be nested in a compute operation");
820   return success();
821 }
822 
823 //===----------------------------------------------------------------------===//
824 // ShutdownOp
825 //===----------------------------------------------------------------------===//
826 
verify()827 LogicalResult acc::ShutdownOp::verify() {
828   Operation *currOp = *this;
829   while ((currOp = currOp->getParentOp()))
830     if (isComputeOperation(currOp))
831       return emitOpError("cannot be nested in a compute operation");
832   return success();
833 }
834 
835 //===----------------------------------------------------------------------===//
836 // UpdateOp
837 //===----------------------------------------------------------------------===//
838 
verify()839 LogicalResult acc::UpdateOp::verify() {
840   // At least one of host or device should have a value.
841   if (hostOperands().empty() && deviceOperands().empty())
842     return emitError(
843         "at least one value must be present in hostOperands or deviceOperands");
844 
845   // The async attribute represent the async clause without value. Therefore the
846   // attribute and operand cannot appear at the same time.
847   if (asyncOperand() && async())
848     return emitError("async attribute cannot appear with asyncOperand");
849 
850   // The wait attribute represent the wait clause without values. Therefore the
851   // attribute and operands cannot appear at the same time.
852   if (!waitOperands().empty() && wait())
853     return emitError("wait attribute cannot appear with waitOperands");
854 
855   if (waitDevnum() && waitOperands().empty())
856     return emitError("wait_devnum cannot appear without waitOperands");
857 
858   return success();
859 }
860 
getNumDataOperands()861 unsigned UpdateOp::getNumDataOperands() {
862   return hostOperands().size() + deviceOperands().size();
863 }
864 
getDataOperand(unsigned i)865 Value UpdateOp::getDataOperand(unsigned i) {
866   unsigned numOptional = asyncOperand() ? 1 : 0;
867   numOptional += waitDevnum() ? 1 : 0;
868   numOptional += ifCond() ? 1 : 0;
869   return getOperand(waitOperands().size() + deviceTypeOperands().size() +
870                     numOptional + i);
871 }
872 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)873 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
874                                            MLIRContext *context) {
875   results.add<RemoveConstantIfCondition<UpdateOp>>(context);
876 }
877 
878 //===----------------------------------------------------------------------===//
879 // WaitOp
880 //===----------------------------------------------------------------------===//
881 
verify()882 LogicalResult acc::WaitOp::verify() {
883   // The async attribute represent the async clause without value. Therefore the
884   // attribute and operand cannot appear at the same time.
885   if (asyncOperand() && async())
886     return emitError("async attribute cannot appear with asyncOperand");
887 
888   if (waitDevnum() && waitOperands().empty())
889     return emitError("wait_devnum cannot appear without waitOperands");
890 
891   return success();
892 }
893 
894 #define GET_OP_CLASSES
895 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
896 
897 #define GET_ATTRDEF_CLASSES
898 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
899