1 //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // =============================================================================
8
9 #include "mlir/Dialect/OpenACC/OpenACC.h"
10 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "llvm/ADT/TypeSwitch.h"
18
19 using namespace mlir;
20 using namespace acc;
21
22 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
23
24 //===----------------------------------------------------------------------===//
25 // OpenACC operations
26 //===----------------------------------------------------------------------===//
27
initialize()28 void OpenACCDialect::initialize() {
29 addOperations<
30 #define GET_OP_LIST
31 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
32 >();
33 addAttributes<
34 #define GET_ATTRDEF_LIST
35 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
36 >();
37 }
38
39 template <typename StructureOp>
parseRegions(OpAsmParser & parser,OperationState & state,unsigned nRegions=1)40 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
41 unsigned nRegions = 1) {
42
43 SmallVector<Region *, 2> regions;
44 for (unsigned i = 0; i < nRegions; ++i)
45 regions.push_back(state.addRegion());
46
47 for (Region *region : regions) {
48 if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
49 return failure();
50 }
51
52 return success();
53 }
54
55 static ParseResult
parseOperandList(OpAsmParser & parser,StringRef keyword,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & args,SmallVectorImpl<Type> & argTypes,OperationState & result)56 parseOperandList(OpAsmParser &parser, StringRef keyword,
57 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &args,
58 SmallVectorImpl<Type> &argTypes, OperationState &result) {
59 if (failed(parser.parseOptionalKeyword(keyword)))
60 return success();
61
62 if (failed(parser.parseLParen()))
63 return failure();
64
65 // Exit early if the list is empty.
66 if (succeeded(parser.parseOptionalRParen()))
67 return success();
68
69 if (failed(parser.parseCommaSeparatedList([&]() {
70 OpAsmParser::UnresolvedOperand arg;
71 Type type;
72
73 if (parser.parseOperand(arg, /*allowResultNumber=*/false) ||
74 parser.parseColonType(type))
75 return failure();
76
77 args.push_back(arg);
78 argTypes.push_back(type);
79 return success();
80 })) ||
81 failed(parser.parseRParen()))
82 return failure();
83
84 return parser.resolveOperands(args, argTypes, parser.getCurrentLocation(),
85 result.operands);
86 }
87
printOperandList(Operation::operand_range operands,StringRef listName,OpAsmPrinter & printer)88 static void printOperandList(Operation::operand_range operands,
89 StringRef listName, OpAsmPrinter &printer) {
90
91 if (!operands.empty()) {
92 printer << " " << listName << "(";
93 llvm::interleaveComma(operands, printer, [&](Value op) {
94 printer << op << ": " << op.getType();
95 });
96 printer << ")";
97 }
98 }
99
parseOptionalOperand(OpAsmParser & parser,StringRef keyword,OpAsmParser::UnresolvedOperand & operand,Type type,bool & hasOptional,OperationState & result)100 static ParseResult parseOptionalOperand(OpAsmParser &parser, StringRef keyword,
101 OpAsmParser::UnresolvedOperand &operand,
102 Type type, bool &hasOptional,
103 OperationState &result) {
104 hasOptional = false;
105 if (succeeded(parser.parseOptionalKeyword(keyword))) {
106 hasOptional = true;
107 if (parser.parseLParen() || parser.parseOperand(operand) ||
108 parser.resolveOperand(operand, type, result.operands) ||
109 parser.parseRParen())
110 return failure();
111 }
112 return success();
113 }
114
parseOperandAndType(OpAsmParser & parser,OperationState & result)115 static ParseResult parseOperandAndType(OpAsmParser &parser,
116 OperationState &result) {
117 OpAsmParser::UnresolvedOperand operand;
118 Type type;
119 if (parser.parseOperand(operand) || parser.parseColonType(type) ||
120 parser.resolveOperand(operand, type, result.operands))
121 return failure();
122 return success();
123 }
124
125 /// Parse optional operand and its type wrapped in parenthesis prefixed with
126 /// a keyword.
127 /// Example:
128 /// keyword `(` %vectorLength: i64 `)`
parseOptionalOperandAndType(OpAsmParser & parser,StringRef keyword,OperationState & result)129 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser,
130 StringRef keyword,
131 OperationState &result) {
132 OpAsmParser::UnresolvedOperand operand;
133 if (succeeded(parser.parseOptionalKeyword(keyword))) {
134 return failure(parser.parseLParen() ||
135 parseOperandAndType(parser, result) || parser.parseRParen());
136 }
137 return llvm::None;
138 }
139
140 /// Parse optional operand and its type wrapped in parenthesis.
141 /// Example:
142 /// `(` %vectorLength: i64 `)`
parseOptionalOperandAndType(OpAsmParser & parser,OperationState & result)143 static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser,
144 OperationState &result) {
145 if (succeeded(parser.parseOptionalLParen())) {
146 return failure(parseOperandAndType(parser, result) || parser.parseRParen());
147 }
148 return llvm::None;
149 }
150
151 /// Parse optional operand with its type prefixed with prefixKeyword `=`.
152 /// Example:
153 /// num=%gangNum: i32
parserOptionalOperandAndTypeWithPrefix(OpAsmParser & parser,OperationState & result,StringRef prefixKeyword)154 static OptionalParseResult parserOptionalOperandAndTypeWithPrefix(
155 OpAsmParser &parser, OperationState &result, StringRef prefixKeyword) {
156 if (succeeded(parser.parseOptionalKeyword(prefixKeyword))) {
157 if (parser.parseEqual() || parseOperandAndType(parser, result))
158 return failure();
159 return success();
160 }
161 return llvm::None;
162 }
163
isComputeOperation(Operation * op)164 static bool isComputeOperation(Operation *op) {
165 return isa<acc::ParallelOp>(op) || isa<acc::LoopOp>(op);
166 }
167
168 namespace {
169 /// Pattern to remove operation without region that have constant false `ifCond`
170 /// and remove the condition from the operation if the `ifCond` is a true
171 /// constant.
172 template <typename OpTy>
173 struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
174 using OpRewritePattern<OpTy>::OpRewritePattern;
175
matchAndRewrite__anon648eb3690311::RemoveConstantIfCondition176 LogicalResult matchAndRewrite(OpTy op,
177 PatternRewriter &rewriter) const override {
178 // Early return if there is no condition.
179 Value ifCond = op.ifCond();
180 if (!ifCond)
181 return success();
182
183 IntegerAttr constAttr;
184 if (matchPattern(ifCond, m_Constant(&constAttr))) {
185 if (constAttr.getInt())
186 rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); });
187 else
188 rewriter.eraseOp(op);
189 }
190
191 return success();
192 }
193 };
194 } // namespace
195
196 //===----------------------------------------------------------------------===//
197 // ParallelOp
198 //===----------------------------------------------------------------------===//
199
200 /// Parse acc.parallel operation
201 /// operation := `acc.parallel` `async` `(` index `)`?
202 /// `wait` `(` index-list `)`?
203 /// `num_gangs` `(` value `)`?
204 /// `num_workers` `(` value `)`?
205 /// `vector_length` `(` value `)`?
206 /// `if` `(` value `)`?
207 /// `self` `(` value `)`?
208 /// `reduction` `(` value-list `)`?
209 /// `copy` `(` value-list `)`?
210 /// `copyin` `(` value-list `)`?
211 /// `copyin_readonly` `(` value-list `)`?
212 /// `copyout` `(` value-list `)`?
213 /// `copyout_zero` `(` value-list `)`?
214 /// `create` `(` value-list `)`?
215 /// `create_zero` `(` value-list `)`?
216 /// `no_create` `(` value-list `)`?
217 /// `present` `(` value-list `)`?
218 /// `deviceptr` `(` value-list `)`?
219 /// `attach` `(` value-list `)`?
220 /// `private` `(` value-list `)`?
221 /// `firstprivate` `(` value-list `)`?
222 /// region attr-dict?
parse(OpAsmParser & parser,OperationState & result)223 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
224 Builder &builder = parser.getBuilder();
225 SmallVector<OpAsmParser::UnresolvedOperand, 8> privateOperands,
226 firstprivateOperands, copyOperands, copyinOperands,
227 copyinReadonlyOperands, copyoutOperands, copyoutZeroOperands,
228 createOperands, createZeroOperands, noCreateOperands, presentOperands,
229 devicePtrOperands, attachOperands, waitOperands, reductionOperands;
230 SmallVector<Type, 8> waitOperandTypes, reductionOperandTypes,
231 copyOperandTypes, copyinOperandTypes, copyinReadonlyOperandTypes,
232 copyoutOperandTypes, copyoutZeroOperandTypes, createOperandTypes,
233 createZeroOperandTypes, noCreateOperandTypes, presentOperandTypes,
234 deviceptrOperandTypes, attachOperandTypes, privateOperandTypes,
235 firstprivateOperandTypes;
236
237 SmallVector<Type, 8> operandTypes;
238 OpAsmParser::UnresolvedOperand ifCond, selfCond;
239 bool hasIfCond = false, hasSelfCond = false;
240 OptionalParseResult async, numGangs, numWorkers, vectorLength;
241 Type i1Type = builder.getI1Type();
242
243 // async()?
244 async = parseOptionalOperandAndType(parser, ParallelOp::getAsyncKeyword(),
245 result);
246 if (async.hasValue() && failed(*async))
247 return failure();
248
249 // wait()?
250 if (failed(parseOperandList(parser, ParallelOp::getWaitKeyword(),
251 waitOperands, waitOperandTypes, result)))
252 return failure();
253
254 // num_gangs(value)?
255 numGangs = parseOptionalOperandAndType(
256 parser, ParallelOp::getNumGangsKeyword(), result);
257 if (numGangs.hasValue() && failed(*numGangs))
258 return failure();
259
260 // num_workers(value)?
261 numWorkers = parseOptionalOperandAndType(
262 parser, ParallelOp::getNumWorkersKeyword(), result);
263 if (numWorkers.hasValue() && failed(*numWorkers))
264 return failure();
265
266 // vector_length(value)?
267 vectorLength = parseOptionalOperandAndType(
268 parser, ParallelOp::getVectorLengthKeyword(), result);
269 if (vectorLength.hasValue() && failed(*vectorLength))
270 return failure();
271
272 // if()?
273 if (failed(parseOptionalOperand(parser, ParallelOp::getIfKeyword(), ifCond,
274 i1Type, hasIfCond, result)))
275 return failure();
276
277 // self()?
278 if (failed(parseOptionalOperand(parser, ParallelOp::getSelfKeyword(),
279 selfCond, i1Type, hasSelfCond, result)))
280 return failure();
281
282 // reduction()?
283 if (failed(parseOperandList(parser, ParallelOp::getReductionKeyword(),
284 reductionOperands, reductionOperandTypes,
285 result)))
286 return failure();
287
288 // copy()?
289 if (failed(parseOperandList(parser, ParallelOp::getCopyKeyword(),
290 copyOperands, copyOperandTypes, result)))
291 return failure();
292
293 // copyin()?
294 if (failed(parseOperandList(parser, ParallelOp::getCopyinKeyword(),
295 copyinOperands, copyinOperandTypes, result)))
296 return failure();
297
298 // copyin_readonly()?
299 if (failed(parseOperandList(parser, ParallelOp::getCopyinReadonlyKeyword(),
300 copyinReadonlyOperands,
301 copyinReadonlyOperandTypes, result)))
302 return failure();
303
304 // copyout()?
305 if (failed(parseOperandList(parser, ParallelOp::getCopyoutKeyword(),
306 copyoutOperands, copyoutOperandTypes, result)))
307 return failure();
308
309 // copyout_zero()?
310 if (failed(parseOperandList(parser, ParallelOp::getCopyoutZeroKeyword(),
311 copyoutZeroOperands, copyoutZeroOperandTypes,
312 result)))
313 return failure();
314
315 // create()?
316 if (failed(parseOperandList(parser, ParallelOp::getCreateKeyword(),
317 createOperands, createOperandTypes, result)))
318 return failure();
319
320 // create_zero()?
321 if (failed(parseOperandList(parser, ParallelOp::getCreateZeroKeyword(),
322 createZeroOperands, createZeroOperandTypes,
323 result)))
324 return failure();
325
326 // no_create()?
327 if (failed(parseOperandList(parser, ParallelOp::getNoCreateKeyword(),
328 noCreateOperands, noCreateOperandTypes, result)))
329 return failure();
330
331 // present()?
332 if (failed(parseOperandList(parser, ParallelOp::getPresentKeyword(),
333 presentOperands, presentOperandTypes, result)))
334 return failure();
335
336 // deviceptr()?
337 if (failed(parseOperandList(parser, ParallelOp::getDevicePtrKeyword(),
338 devicePtrOperands, deviceptrOperandTypes,
339 result)))
340 return failure();
341
342 // attach()?
343 if (failed(parseOperandList(parser, ParallelOp::getAttachKeyword(),
344 attachOperands, attachOperandTypes, result)))
345 return failure();
346
347 // private()?
348 if (failed(parseOperandList(parser, ParallelOp::getPrivateKeyword(),
349 privateOperands, privateOperandTypes, result)))
350 return failure();
351
352 // firstprivate()?
353 if (failed(parseOperandList(parser, ParallelOp::getFirstPrivateKeyword(),
354 firstprivateOperands, firstprivateOperandTypes,
355 result)))
356 return failure();
357
358 // Parallel op region
359 if (failed(parseRegions<ParallelOp>(parser, result)))
360 return failure();
361
362 result.addAttribute(
363 ParallelOp::getOperandSegmentSizeAttr(),
364 builder.getI32VectorAttr(
365 {static_cast<int32_t>(async.hasValue() ? 1 : 0),
366 static_cast<int32_t>(waitOperands.size()),
367 static_cast<int32_t>(numGangs.hasValue() ? 1 : 0),
368 static_cast<int32_t>(numWorkers.hasValue() ? 1 : 0),
369 static_cast<int32_t>(vectorLength.hasValue() ? 1 : 0),
370 static_cast<int32_t>(hasIfCond ? 1 : 0),
371 static_cast<int32_t>(hasSelfCond ? 1 : 0),
372 static_cast<int32_t>(reductionOperands.size()),
373 static_cast<int32_t>(copyOperands.size()),
374 static_cast<int32_t>(copyinOperands.size()),
375 static_cast<int32_t>(copyinReadonlyOperands.size()),
376 static_cast<int32_t>(copyoutOperands.size()),
377 static_cast<int32_t>(copyoutZeroOperands.size()),
378 static_cast<int32_t>(createOperands.size()),
379 static_cast<int32_t>(createZeroOperands.size()),
380 static_cast<int32_t>(noCreateOperands.size()),
381 static_cast<int32_t>(presentOperands.size()),
382 static_cast<int32_t>(devicePtrOperands.size()),
383 static_cast<int32_t>(attachOperands.size()),
384 static_cast<int32_t>(privateOperands.size()),
385 static_cast<int32_t>(firstprivateOperands.size())}));
386
387 // Additional attributes
388 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
389 return failure();
390
391 return success();
392 }
393
print(OpAsmPrinter & printer)394 void ParallelOp::print(OpAsmPrinter &printer) {
395 // async()?
396 if (Value async = this->async())
397 printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ": "
398 << async.getType() << ")";
399
400 // wait()?
401 printOperandList(waitOperands(), ParallelOp::getWaitKeyword(), printer);
402
403 // num_gangs()?
404 if (Value numGangs = this->numGangs())
405 printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs
406 << ": " << numGangs.getType() << ")";
407
408 // num_workers()?
409 if (Value numWorkers = this->numWorkers())
410 printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers
411 << ": " << numWorkers.getType() << ")";
412
413 // vector_length()?
414 if (Value vectorLength = this->vectorLength())
415 printer << " " << ParallelOp::getVectorLengthKeyword() << "("
416 << vectorLength << ": " << vectorLength.getType() << ")";
417
418 // if()?
419 if (Value ifCond = this->ifCond())
420 printer << " " << ParallelOp::getIfKeyword() << "(" << ifCond << ")";
421
422 // self()?
423 if (Value selfCond = this->selfCond())
424 printer << " " << ParallelOp::getSelfKeyword() << "(" << selfCond << ")";
425
426 // reduction()?
427 printOperandList(reductionOperands(), ParallelOp::getReductionKeyword(),
428 printer);
429
430 // copy()?
431 printOperandList(copyOperands(), ParallelOp::getCopyKeyword(), printer);
432
433 // copyin()?
434 printOperandList(copyinOperands(), ParallelOp::getCopyinKeyword(), printer);
435
436 // copyin_readonly()?
437 printOperandList(copyinReadonlyOperands(),
438 ParallelOp::getCopyinReadonlyKeyword(), printer);
439
440 // copyout()?
441 printOperandList(copyoutOperands(), ParallelOp::getCopyoutKeyword(), printer);
442
443 // copyout_zero()?
444 printOperandList(copyoutZeroOperands(), ParallelOp::getCopyoutZeroKeyword(),
445 printer);
446
447 // create()?
448 printOperandList(createOperands(), ParallelOp::getCreateKeyword(), printer);
449
450 // create_zero()?
451 printOperandList(createZeroOperands(), ParallelOp::getCreateZeroKeyword(),
452 printer);
453
454 // no_create()?
455 printOperandList(noCreateOperands(), ParallelOp::getNoCreateKeyword(),
456 printer);
457
458 // present()?
459 printOperandList(presentOperands(), ParallelOp::getPresentKeyword(), printer);
460
461 // deviceptr()?
462 printOperandList(devicePtrOperands(), ParallelOp::getDevicePtrKeyword(),
463 printer);
464
465 // attach()?
466 printOperandList(attachOperands(), ParallelOp::getAttachKeyword(), printer);
467
468 // private()?
469 printOperandList(gangPrivateOperands(), ParallelOp::getPrivateKeyword(),
470 printer);
471
472 // firstprivate()?
473 printOperandList(gangFirstPrivateOperands(),
474 ParallelOp::getFirstPrivateKeyword(), printer);
475
476 printer << ' ';
477 printer.printRegion(region(),
478 /*printEntryBlockArgs=*/false,
479 /*printBlockTerminators=*/true);
480 printer.printOptionalAttrDictWithKeyword(
481 (*this)->getAttrs(), ParallelOp::getOperandSegmentSizeAttr());
482 }
483
getNumDataOperands()484 unsigned ParallelOp::getNumDataOperands() {
485 return reductionOperands().size() + copyOperands().size() +
486 copyinOperands().size() + copyinReadonlyOperands().size() +
487 copyoutOperands().size() + copyoutZeroOperands().size() +
488 createOperands().size() + createZeroOperands().size() +
489 noCreateOperands().size() + presentOperands().size() +
490 devicePtrOperands().size() + attachOperands().size() +
491 gangPrivateOperands().size() + gangFirstPrivateOperands().size();
492 }
493
getDataOperand(unsigned i)494 Value ParallelOp::getDataOperand(unsigned i) {
495 unsigned numOptional = async() ? 1 : 0;
496 numOptional += numGangs() ? 1 : 0;
497 numOptional += numWorkers() ? 1 : 0;
498 numOptional += vectorLength() ? 1 : 0;
499 numOptional += ifCond() ? 1 : 0;
500 numOptional += selfCond() ? 1 : 0;
501 return getOperand(waitOperands().size() + numOptional + i);
502 }
503
504 //===----------------------------------------------------------------------===//
505 // LoopOp
506 //===----------------------------------------------------------------------===//
507
508 /// Parse acc.loop operation
509 /// operation := `acc.loop`
510 /// (`gang` ( `(` (`num=` value)? (`,` `static=` value `)`)? )? )?
511 /// (`vector` ( `(` value `)` )? )? (`worker` (`(` value `)`)? )?
512 /// (`vector_length` `(` value `)`)?
513 /// (`tile` `(` value-list `)`)?
514 /// (`private` `(` value-list `)`)?
515 /// (`reduction` `(` value-list `)`)?
516 /// region attr-dict?
parse(OpAsmParser & parser,OperationState & result)517 ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
518 Builder &builder = parser.getBuilder();
519 unsigned executionMapping = OpenACCExecMapping::NONE;
520 SmallVector<Type, 8> operandTypes;
521 SmallVector<OpAsmParser::UnresolvedOperand, 8> privateOperands,
522 reductionOperands;
523 SmallVector<OpAsmParser::UnresolvedOperand, 8> tileOperands;
524 OptionalParseResult gangNum, gangStatic, worker, vector;
525
526 // gang?
527 if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangKeyword())))
528 executionMapping |= OpenACCExecMapping::GANG;
529
530 // optional gang operand
531 if (succeeded(parser.parseOptionalLParen())) {
532 gangNum = parserOptionalOperandAndTypeWithPrefix(
533 parser, result, LoopOp::getGangNumKeyword());
534 if (gangNum.hasValue() && failed(*gangNum))
535 return failure();
536 // FIXME: Comma should require subsequent operands.
537 (void)parser.parseOptionalComma();
538 gangStatic = parserOptionalOperandAndTypeWithPrefix(
539 parser, result, LoopOp::getGangStaticKeyword());
540 if (gangStatic.hasValue() && failed(*gangStatic))
541 return failure();
542 // FIXME: Why allow optional last commas?
543 (void)parser.parseOptionalComma();
544 if (failed(parser.parseRParen()))
545 return failure();
546 }
547
548 // worker?
549 if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerKeyword())))
550 executionMapping |= OpenACCExecMapping::WORKER;
551
552 // optional worker operand
553 worker = parseOptionalOperandAndType(parser, result);
554 if (worker.hasValue() && failed(*worker))
555 return failure();
556
557 // vector?
558 if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorKeyword())))
559 executionMapping |= OpenACCExecMapping::VECTOR;
560
561 // optional vector operand
562 vector = parseOptionalOperandAndType(parser, result);
563 if (vector.hasValue() && failed(*vector))
564 return failure();
565
566 // tile()?
567 if (failed(parseOperandList(parser, LoopOp::getTileKeyword(), tileOperands,
568 operandTypes, result)))
569 return failure();
570
571 // private()?
572 if (failed(parseOperandList(parser, LoopOp::getPrivateKeyword(),
573 privateOperands, operandTypes, result)))
574 return failure();
575
576 // reduction()?
577 if (failed(parseOperandList(parser, LoopOp::getReductionKeyword(),
578 reductionOperands, operandTypes, result)))
579 return failure();
580
581 if (executionMapping != acc::OpenACCExecMapping::NONE)
582 result.addAttribute(LoopOp::getExecutionMappingAttrName(),
583 builder.getI64IntegerAttr(executionMapping));
584
585 // Parse optional results in case there is a reduce.
586 if (parser.parseOptionalArrowTypeList(result.types))
587 return failure();
588
589 if (failed(parseRegions<LoopOp>(parser, result)))
590 return failure();
591
592 result.addAttribute(LoopOp::getOperandSegmentSizeAttr(),
593 builder.getI32VectorAttr(
594 {static_cast<int32_t>(gangNum.hasValue() ? 1 : 0),
595 static_cast<int32_t>(gangStatic.hasValue() ? 1 : 0),
596 static_cast<int32_t>(worker.hasValue() ? 1 : 0),
597 static_cast<int32_t>(vector.hasValue() ? 1 : 0),
598 static_cast<int32_t>(tileOperands.size()),
599 static_cast<int32_t>(privateOperands.size()),
600 static_cast<int32_t>(reductionOperands.size())}));
601
602 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
603 return failure();
604
605 return success();
606 }
607
print(OpAsmPrinter & printer)608 void LoopOp::print(OpAsmPrinter &printer) {
609 unsigned execMapping = exec_mapping();
610 if (execMapping & OpenACCExecMapping::GANG) {
611 printer << " " << LoopOp::getGangKeyword();
612 Value gangNum = this->gangNum();
613 Value gangStatic = this->gangStatic();
614
615 // Print optional gang operands
616 if (gangNum || gangStatic) {
617 printer << "(";
618 if (gangNum) {
619 printer << LoopOp::getGangNumKeyword() << "=" << gangNum << ": "
620 << gangNum.getType();
621 if (gangStatic)
622 printer << ", ";
623 }
624 if (gangStatic)
625 printer << LoopOp::getGangStaticKeyword() << "=" << gangStatic << ": "
626 << gangStatic.getType();
627 printer << ")";
628 }
629 }
630
631 if (execMapping & OpenACCExecMapping::WORKER) {
632 printer << " " << LoopOp::getWorkerKeyword();
633
634 // Print optional worker operand if present
635 if (Value workerNum = this->workerNum())
636 printer << "(" << workerNum << ": " << workerNum.getType() << ")";
637 }
638
639 if (execMapping & OpenACCExecMapping::VECTOR) {
640 printer << " " << LoopOp::getVectorKeyword();
641
642 // Print optional vector operand if present
643 if (Value vectorLength = this->vectorLength())
644 printer << "(" << vectorLength << ": " << vectorLength.getType() << ")";
645 }
646
647 // tile()?
648 printOperandList(tileOperands(), LoopOp::getTileKeyword(), printer);
649
650 // private()?
651 printOperandList(privateOperands(), LoopOp::getPrivateKeyword(), printer);
652
653 // reduction()?
654 printOperandList(reductionOperands(), LoopOp::getReductionKeyword(), printer);
655
656 if (getNumResults() > 0)
657 printer << " -> (" << getResultTypes() << ")";
658
659 printer << ' ';
660 printer.printRegion(region(),
661 /*printEntryBlockArgs=*/false,
662 /*printBlockTerminators=*/true);
663
664 printer.printOptionalAttrDictWithKeyword(
665 (*this)->getAttrs(), {LoopOp::getExecutionMappingAttrName(),
666 LoopOp::getOperandSegmentSizeAttr()});
667 }
668
verify()669 LogicalResult acc::LoopOp::verify() {
670 // auto, independent and seq attribute are mutually exclusive.
671 if ((auto_() && (independent() || seq())) || (independent() && seq())) {
672 return emitError("only one of " + acc::LoopOp::getAutoAttrName() + ", " +
673 acc::LoopOp::getIndependentAttrName() + ", " +
674 acc::LoopOp::getSeqAttrName() +
675 " can be present at the same time");
676 }
677
678 // Gang, worker and vector are incompatible with seq.
679 if (seq() && exec_mapping() != OpenACCExecMapping::NONE)
680 return emitError("gang, worker or vector cannot appear with the seq attr");
681
682 // Check non-empty body().
683 if (region().empty())
684 return emitError("expected non-empty body.");
685
686 return success();
687 }
688
689 //===----------------------------------------------------------------------===//
690 // DataOp
691 //===----------------------------------------------------------------------===//
692
verify()693 LogicalResult acc::DataOp::verify() {
694 // 2.6.5. Data Construct restriction
695 // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
696 // attach, or default clause must appear on a data construct.
697 if (getOperands().empty() && !defaultAttr())
698 return emitError("at least one operand or the default attribute "
699 "must appear on the data operation");
700 return success();
701 }
702
getNumDataOperands()703 unsigned DataOp::getNumDataOperands() {
704 return copyOperands().size() + copyinOperands().size() +
705 copyinReadonlyOperands().size() + copyoutOperands().size() +
706 copyoutZeroOperands().size() + createOperands().size() +
707 createZeroOperands().size() + noCreateOperands().size() +
708 presentOperands().size() + deviceptrOperands().size() +
709 attachOperands().size();
710 }
711
getDataOperand(unsigned i)712 Value DataOp::getDataOperand(unsigned i) {
713 unsigned numOptional = ifCond() ? 1 : 0;
714 return getOperand(numOptional + i);
715 }
716
717 //===----------------------------------------------------------------------===//
718 // ExitDataOp
719 //===----------------------------------------------------------------------===//
720
verify()721 LogicalResult acc::ExitDataOp::verify() {
722 // 2.6.6. Data Exit Directive restriction
723 // At least one copyout, delete, or detach clause must appear on an exit data
724 // directive.
725 if (copyoutOperands().empty() && deleteOperands().empty() &&
726 detachOperands().empty())
727 return emitError(
728 "at least one operand in copyout, delete or detach must appear on the "
729 "exit data operation");
730
731 // The async attribute represent the async clause without value. Therefore the
732 // attribute and operand cannot appear at the same time.
733 if (asyncOperand() && async())
734 return emitError("async attribute cannot appear with asyncOperand");
735
736 // The wait attribute represent the wait clause without values. Therefore the
737 // attribute and operands cannot appear at the same time.
738 if (!waitOperands().empty() && wait())
739 return emitError("wait attribute cannot appear with waitOperands");
740
741 if (waitDevnum() && waitOperands().empty())
742 return emitError("wait_devnum cannot appear without waitOperands");
743
744 return success();
745 }
746
getNumDataOperands()747 unsigned ExitDataOp::getNumDataOperands() {
748 return copyoutOperands().size() + deleteOperands().size() +
749 detachOperands().size();
750 }
751
getDataOperand(unsigned i)752 Value ExitDataOp::getDataOperand(unsigned i) {
753 unsigned numOptional = ifCond() ? 1 : 0;
754 numOptional += asyncOperand() ? 1 : 0;
755 numOptional += waitDevnum() ? 1 : 0;
756 return getOperand(waitOperands().size() + numOptional + i);
757 }
758
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)759 void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
760 MLIRContext *context) {
761 results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
762 }
763
764 //===----------------------------------------------------------------------===//
765 // EnterDataOp
766 //===----------------------------------------------------------------------===//
767
verify()768 LogicalResult acc::EnterDataOp::verify() {
769 // 2.6.6. Data Enter Directive restriction
770 // At least one copyin, create, or attach clause must appear on an enter data
771 // directive.
772 if (copyinOperands().empty() && createOperands().empty() &&
773 createZeroOperands().empty() && attachOperands().empty())
774 return emitError(
775 "at least one operand in copyin, create, "
776 "create_zero or attach must appear on the enter data operation");
777
778 // The async attribute represent the async clause without value. Therefore the
779 // attribute and operand cannot appear at the same time.
780 if (asyncOperand() && async())
781 return emitError("async attribute cannot appear with asyncOperand");
782
783 // The wait attribute represent the wait clause without values. Therefore the
784 // attribute and operands cannot appear at the same time.
785 if (!waitOperands().empty() && wait())
786 return emitError("wait attribute cannot appear with waitOperands");
787
788 if (waitDevnum() && waitOperands().empty())
789 return emitError("wait_devnum cannot appear without waitOperands");
790
791 return success();
792 }
793
getNumDataOperands()794 unsigned EnterDataOp::getNumDataOperands() {
795 return copyinOperands().size() + createOperands().size() +
796 createZeroOperands().size() + attachOperands().size();
797 }
798
getDataOperand(unsigned i)799 Value EnterDataOp::getDataOperand(unsigned i) {
800 unsigned numOptional = ifCond() ? 1 : 0;
801 numOptional += asyncOperand() ? 1 : 0;
802 numOptional += waitDevnum() ? 1 : 0;
803 return getOperand(waitOperands().size() + numOptional + i);
804 }
805
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)806 void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
807 MLIRContext *context) {
808 results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
809 }
810
811 //===----------------------------------------------------------------------===//
812 // InitOp
813 //===----------------------------------------------------------------------===//
814
verify()815 LogicalResult acc::InitOp::verify() {
816 Operation *currOp = *this;
817 while ((currOp = currOp->getParentOp()))
818 if (isComputeOperation(currOp))
819 return emitOpError("cannot be nested in a compute operation");
820 return success();
821 }
822
823 //===----------------------------------------------------------------------===//
824 // ShutdownOp
825 //===----------------------------------------------------------------------===//
826
verify()827 LogicalResult acc::ShutdownOp::verify() {
828 Operation *currOp = *this;
829 while ((currOp = currOp->getParentOp()))
830 if (isComputeOperation(currOp))
831 return emitOpError("cannot be nested in a compute operation");
832 return success();
833 }
834
835 //===----------------------------------------------------------------------===//
836 // UpdateOp
837 //===----------------------------------------------------------------------===//
838
verify()839 LogicalResult acc::UpdateOp::verify() {
840 // At least one of host or device should have a value.
841 if (hostOperands().empty() && deviceOperands().empty())
842 return emitError(
843 "at least one value must be present in hostOperands or deviceOperands");
844
845 // The async attribute represent the async clause without value. Therefore the
846 // attribute and operand cannot appear at the same time.
847 if (asyncOperand() && async())
848 return emitError("async attribute cannot appear with asyncOperand");
849
850 // The wait attribute represent the wait clause without values. Therefore the
851 // attribute and operands cannot appear at the same time.
852 if (!waitOperands().empty() && wait())
853 return emitError("wait attribute cannot appear with waitOperands");
854
855 if (waitDevnum() && waitOperands().empty())
856 return emitError("wait_devnum cannot appear without waitOperands");
857
858 return success();
859 }
860
getNumDataOperands()861 unsigned UpdateOp::getNumDataOperands() {
862 return hostOperands().size() + deviceOperands().size();
863 }
864
getDataOperand(unsigned i)865 Value UpdateOp::getDataOperand(unsigned i) {
866 unsigned numOptional = asyncOperand() ? 1 : 0;
867 numOptional += waitDevnum() ? 1 : 0;
868 numOptional += ifCond() ? 1 : 0;
869 return getOperand(waitOperands().size() + deviceTypeOperands().size() +
870 numOptional + i);
871 }
872
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)873 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
874 MLIRContext *context) {
875 results.add<RemoveConstantIfCondition<UpdateOp>>(context);
876 }
877
878 //===----------------------------------------------------------------------===//
879 // WaitOp
880 //===----------------------------------------------------------------------===//
881
verify()882 LogicalResult acc::WaitOp::verify() {
883 // The async attribute represent the async clause without value. Therefore the
884 // attribute and operand cannot appear at the same time.
885 if (asyncOperand() && async())
886 return emitError("async attribute cannot appear with asyncOperand");
887
888 if (waitDevnum() && waitOperands().empty())
889 return emitError("wait_devnum cannot appear without waitOperands");
890
891 return success();
892 }
893
894 #define GET_OP_CLASSES
895 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
896
897 #define GET_ATTRDEF_CLASSES
898 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
899