1 //===-- FIROps.cpp --------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Dialect/FIROps.h"
10 #include "flang/Optimizer/Dialect/FIRAttr.h"
11 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
12 #include "flang/Optimizer/Dialect/FIRType.h"
13 #include "mlir/ADT/TypeSwitch.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/Function.h"
17 #include "mlir/IR/Module.h"
18 #include "mlir/IR/StandardTypes.h"
19 #include "mlir/IR/SymbolTable.h"
20 #include "llvm/ADT/StringSwitch.h"
21 
22 using namespace fir;
23 
24 /// return true if the sequence type is abstract or the record type is malformed
25 /// or contains an abstract sequence type
26 static bool verifyInType(mlir::Type inType,
27                          llvm::SmallVectorImpl<llvm::StringRef> &visited) {
28   if (auto st = inType.dyn_cast<fir::SequenceType>()) {
29     auto shape = st.getShape();
30     if (shape.size() == 0)
31       return true;
32     for (auto ext : shape)
33       if (ext < 0)
34         return true;
35   } else if (auto rt = inType.dyn_cast<fir::RecordType>()) {
36     // don't recurse if we're already visiting this one
37     if (llvm::is_contained(visited, rt.getName()))
38       return false;
39     // keep track of record types currently being visited
40     visited.push_back(rt.getName());
41     for (auto &field : rt.getTypeList())
42       if (verifyInType(field.second, visited))
43         return true;
44     visited.pop_back();
45   } else if (auto rt = inType.dyn_cast<fir::PointerType>()) {
46     return verifyInType(rt.getEleTy(), visited);
47   }
48   return false;
49 }
50 
51 static bool verifyRecordLenParams(mlir::Type inType, unsigned numLenParams) {
52   if (numLenParams > 0) {
53     if (auto rt = inType.dyn_cast<fir::RecordType>())
54       return numLenParams != rt.getNumLenParams();
55     return true;
56   }
57   return false;
58 }
59 
60 //===----------------------------------------------------------------------===//
61 // AllocaOp
62 //===----------------------------------------------------------------------===//
63 
64 mlir::Type fir::AllocaOp::getAllocatedType() {
65   return getType().cast<ReferenceType>().getEleTy();
66 }
67 
68 /// Create a legal memory reference as return type
69 mlir::Type fir::AllocaOp::wrapResultType(mlir::Type intype) {
70   // FIR semantics: memory references to memory references are disallowed
71   if (intype.isa<ReferenceType>())
72     return {};
73   return ReferenceType::get(intype);
74 }
75 
76 mlir::Type fir::AllocaOp::getRefTy(mlir::Type ty) {
77   return ReferenceType::get(ty);
78 }
79 
80 //===----------------------------------------------------------------------===//
81 // AllocMemOp
82 //===----------------------------------------------------------------------===//
83 
84 mlir::Type fir::AllocMemOp::getAllocatedType() {
85   return getType().cast<HeapType>().getEleTy();
86 }
87 
88 mlir::Type fir::AllocMemOp::getRefTy(mlir::Type ty) {
89   return HeapType::get(ty);
90 }
91 
92 /// Create a legal heap reference as return type
93 mlir::Type fir::AllocMemOp::wrapResultType(mlir::Type intype) {
94   // Fortran semantics: C852 an entity cannot be both ALLOCATABLE and POINTER
95   // 8.5.3 note 1 prohibits ALLOCATABLE procedures as well
96   // FIR semantics: one may not allocate a memory reference value
97   if (intype.isa<ReferenceType>() || intype.isa<HeapType>() ||
98       intype.isa<PointerType>() || intype.isa<FunctionType>())
99     return {};
100   return HeapType::get(intype);
101 }
102 
103 //===----------------------------------------------------------------------===//
104 // BoxDimsOp
105 //===----------------------------------------------------------------------===//
106 
107 /// Get the result types packed in a tuple tuple
108 mlir::Type fir::BoxDimsOp::getTupleType() {
109   // note: triple, but 4 is nearest power of 2
110   llvm::SmallVector<mlir::Type, 4> triple{
111       getResult(0).getType(), getResult(1).getType(), getResult(2).getType()};
112   return mlir::TupleType::get(triple, getContext());
113 }
114 
115 //===----------------------------------------------------------------------===//
116 // CallOp
117 //===----------------------------------------------------------------------===//
118 
119 static void printCallOp(mlir::OpAsmPrinter &p, fir::CallOp &op) {
120   auto callee = op.callee();
121   bool isDirect = callee.hasValue();
122   p << op.getOperationName() << ' ';
123   if (isDirect)
124     p << callee.getValue();
125   else
126     p << op.getOperand(0);
127   p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')';
128   p.printOptionalAttrDict(op.getAttrs(), {fir::CallOp::calleeAttrName()});
129   auto resultTypes{op.getResultTypes()};
130   llvm::SmallVector<Type, 8> argTypes(
131       llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));
132   p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext());
133 }
134 
135 static mlir::ParseResult parseCallOp(mlir::OpAsmParser &parser,
136                                      mlir::OperationState &result) {
137   llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> operands;
138   if (parser.parseOperandList(operands))
139     return mlir::failure();
140 
141   llvm::SmallVector<mlir::NamedAttribute, 4> attrs;
142   mlir::SymbolRefAttr funcAttr;
143   bool isDirect = operands.empty();
144   if (isDirect)
145     if (parser.parseAttribute(funcAttr, fir::CallOp::calleeAttrName(), attrs))
146       return mlir::failure();
147 
148   Type type;
149   if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) ||
150       parser.parseOptionalAttrDict(attrs) || parser.parseColon() ||
151       parser.parseType(type))
152     return mlir::failure();
153 
154   auto funcType = type.dyn_cast<mlir::FunctionType>();
155   if (!funcType)
156     return parser.emitError(parser.getNameLoc(), "expected function type");
157   if (isDirect) {
158     if (parser.resolveOperands(operands, funcType.getInputs(),
159                                parser.getNameLoc(), result.operands))
160       return mlir::failure();
161   } else {
162     auto funcArgs =
163         llvm::ArrayRef<mlir::OpAsmParser::OperandType>(operands).drop_front();
164     llvm::SmallVector<mlir::Value, 8> resultArgs(
165         result.operands.begin() + (result.operands.empty() ? 0 : 1),
166         result.operands.end());
167     if (parser.resolveOperand(operands[0], funcType, result.operands) ||
168         parser.resolveOperands(funcArgs, funcType.getInputs(),
169                                parser.getNameLoc(), resultArgs))
170       return mlir::failure();
171   }
172   result.addTypes(funcType.getResults());
173   result.attributes = attrs;
174   return mlir::success();
175 }
176 
177 //===----------------------------------------------------------------------===//
178 // CmpfOp
179 //===----------------------------------------------------------------------===//
180 
181 // Note: getCmpFPredicateNames() is inline static in StandardOps/IR/Ops.cpp
182 mlir::CmpFPredicate fir::CmpfOp::getPredicateByName(llvm::StringRef name) {
183   auto pred = mlir::symbolizeCmpFPredicate(name);
184   assert(pred.hasValue() && "invalid predicate name");
185   return pred.getValue();
186 }
187 
188 void fir::buildCmpFOp(Builder *builder, OperationState &result,
189                       CmpFPredicate predicate, Value lhs, Value rhs) {
190   result.addOperands({lhs, rhs});
191   result.types.push_back(builder->getI1Type());
192   result.addAttribute(
193       CmpfOp::getPredicateAttrName(),
194       builder->getI64IntegerAttr(static_cast<int64_t>(predicate)));
195 }
196 
197 template <typename OPTY>
198 static void printCmpOp(OpAsmPrinter &p, OPTY op) {
199   p << op.getOperationName() << ' ';
200   auto predSym = mlir::symbolizeCmpFPredicate(
201       op.template getAttrOfType<mlir::IntegerAttr>(OPTY::getPredicateAttrName())
202           .getInt());
203   assert(predSym.hasValue() && "invalid symbol value for predicate");
204   p << '"' << mlir::stringifyCmpFPredicate(predSym.getValue()) << '"' << ", ";
205   p.printOperand(op.lhs());
206   p << ", ";
207   p.printOperand(op.rhs());
208   p.printOptionalAttrDict(op.getAttrs(),
209                           /*elidedAttrs=*/{OPTY::getPredicateAttrName()});
210   p << " : " << op.lhs().getType();
211 }
212 
213 static void printCmpfOp(OpAsmPrinter &p, CmpfOp op) { printCmpOp(p, op); }
214 
215 template <typename OPTY>
216 static mlir::ParseResult parseCmpOp(mlir::OpAsmParser &parser,
217                                     mlir::OperationState &result) {
218   llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> ops;
219   llvm::SmallVector<mlir::NamedAttribute, 4> attrs;
220   mlir::Attribute predicateNameAttr;
221   mlir::Type type;
222   if (parser.parseAttribute(predicateNameAttr, OPTY::getPredicateAttrName(),
223                             attrs) ||
224       parser.parseComma() || parser.parseOperandList(ops, 2) ||
225       parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) ||
226       parser.resolveOperands(ops, type, result.operands))
227     return failure();
228 
229   if (!predicateNameAttr.isa<mlir::StringAttr>())
230     return parser.emitError(parser.getNameLoc(),
231                             "expected string comparison predicate attribute");
232 
233   // Rewrite string attribute to an enum value.
234   llvm::StringRef predicateName =
235       predicateNameAttr.cast<mlir::StringAttr>().getValue();
236   auto predicate = fir::CmpfOp::getPredicateByName(predicateName);
237   auto builder = parser.getBuilder();
238   mlir::Type i1Type = builder.getI1Type();
239   attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate));
240   result.attributes = attrs;
241   result.addTypes({i1Type});
242   return success();
243 }
244 
245 mlir::ParseResult fir::parseCmpfOp(mlir::OpAsmParser &parser,
246                                    mlir::OperationState &result) {
247   return parseCmpOp<fir::CmpfOp>(parser, result);
248 }
249 
250 //===----------------------------------------------------------------------===//
251 // CmpcOp
252 //===----------------------------------------------------------------------===//
253 
254 void fir::buildCmpCOp(Builder *builder, OperationState &result,
255                       CmpFPredicate predicate, Value lhs, Value rhs) {
256   result.addOperands({lhs, rhs});
257   result.types.push_back(builder->getI1Type());
258   result.addAttribute(
259       fir::CmpcOp::getPredicateAttrName(),
260       builder->getI64IntegerAttr(static_cast<int64_t>(predicate)));
261 }
262 
263 static void printCmpcOp(OpAsmPrinter &p, fir::CmpcOp op) { printCmpOp(p, op); }
264 
265 mlir::ParseResult fir::parseCmpcOp(mlir::OpAsmParser &parser,
266                                    mlir::OperationState &result) {
267   return parseCmpOp<fir::CmpcOp>(parser, result);
268 }
269 
270 //===----------------------------------------------------------------------===//
271 // DispatchOp
272 //===----------------------------------------------------------------------===//
273 
274 mlir::FunctionType fir::DispatchOp::getFunctionType() {
275   auto attr = getAttr("fn_type").cast<mlir::TypeAttr>();
276   return attr.getValue().cast<mlir::FunctionType>();
277 }
278 
279 //===----------------------------------------------------------------------===//
280 // DispatchTableOp
281 //===----------------------------------------------------------------------===//
282 
283 void fir::DispatchTableOp::appendTableEntry(mlir::Operation *op) {
284   assert(mlir::isa<fir::DTEntryOp>(*op) && "operation must be a DTEntryOp");
285   auto &block = getBlock();
286   block.getOperations().insert(block.end(), op);
287 }
288 
289 //===----------------------------------------------------------------------===//
290 // EmboxOp
291 //===----------------------------------------------------------------------===//
292 
/// Parse an EmboxOp written as
///   %memref [ ( %ops... ) ] [ , %dims | [ <affine-map> ] ] <attr-dict>
///     : <function-type>
/// All collected operands (memref first, then any parenthesized operands, then
/// %dims if present) are resolved against the inputs of the trailing function
/// type. The op's result types are that function type's results.
static mlir::ParseResult parseEmboxOp(mlir::OpAsmParser &parser,
                                      mlir::OperationState &result) {
  mlir::FunctionType type;
  llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> operands;
  mlir::OpAsmParser::OperandType memref;
  if (parser.parseOperand(memref))
    return mlir::failure();
  operands.push_back(memref);
  auto &builder = parser.getBuilder();
  // Optional parenthesized operand list (presumably the LEN parameters, given
  // the lenpName() attribute name; confirm against the op definition).
  if (!parser.parseOptionalLParen()) {
    if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) ||
        parser.parseRParen())
      return mlir::failure();
    // NOTE(review): operands.size() here also counts the leading memref
    // operand, not just the parenthesized ones. Confirm this is the count the
    // lenpName() attribute is meant to hold.
    auto lens = builder.getI32IntegerAttr(operands.size());
    result.addAttribute(fir::EmboxOp::lenpName(), lens);
  }
  // Either `, %dims` (an extra operand) or `[ <affine-map> ]` (stored as the
  // layoutName() attribute); the two forms are mutually exclusive here.
  if (!parser.parseOptionalComma()) {
    mlir::OpAsmParser::OperandType dims;
    if (parser.parseOperand(dims))
      return mlir::failure();
    operands.push_back(dims);
  } else if (!parser.parseOptionalLSquare()) {
    mlir::AffineMapAttr map;
    if (parser.parseAttribute(map, fir::EmboxOp::layoutName(),
                              result.attributes) ||
        parser.parseRSquare())
      return mlir::failure();
  }
  // The function type supplies both the operand types (inputs, in the order
  // collected above) and the result types.
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperands(operands, type.getInputs(), parser.getNameLoc(),
                             result.operands) ||
      parser.addTypesToList(type.getResults(), result.types))
    return mlir::failure();
  return mlir::success();
}
329 
330 //===----------------------------------------------------------------------===//
331 // GenTypeDescOp
332 //===----------------------------------------------------------------------===//
333 
334 void fir::GenTypeDescOp::build(Builder *, OperationState &result,
335                                mlir::TypeAttr inty) {
336   result.addAttribute("in_type", inty);
337   result.addTypes(TypeDescType::get(inty.getValue()));
338 }
339 
340 //===----------------------------------------------------------------------===//
341 // GlobalOp
342 //===----------------------------------------------------------------------===//
343 
344 void fir::GlobalOp::appendInitialValue(mlir::Operation *op) {
345   getBlock().getOperations().push_back(op);
346 }
347 
348 //===----------------------------------------------------------------------===//
349 // LoadOp
350 //===----------------------------------------------------------------------===//
351 
352 /// Get the element type of a reference like type; otherwise null
353 static mlir::Type elementTypeOf(mlir::Type ref) {
354   return mlir::TypeSwitch<mlir::Type, mlir::Type>(ref)
355       .Case<ReferenceType, PointerType, HeapType>(
356           [](auto type) { return type.getEleTy(); })
357       .Default([](mlir::Type) { return mlir::Type{}; });
358 }
359 
360 mlir::ParseResult fir::LoadOp::getElementOf(mlir::Type &ele, mlir::Type ref) {
361   if ((ele = elementTypeOf(ref)))
362     return mlir::success();
363   return mlir::failure();
364 }
365 
366 //===----------------------------------------------------------------------===//
367 // LoopOp
368 //===----------------------------------------------------------------------===//
369 
370 void fir::LoopOp::build(mlir::Builder *builder, OperationState &result,
371                         mlir::Value lb, mlir::Value ub, ValueRange step,
372                         ArrayRef<NamedAttribute> attributes) {
373   if (step.empty())
374     result.addOperands({lb, ub});
375   else
376     result.addOperands({lb, ub, step[0]});
377   mlir::Region *bodyRegion = result.addRegion();
378   LoopOp::ensureTerminator(*bodyRegion, *builder, result.location);
379   bodyRegion->front().addArgument(builder->getIndexType());
380   result.addAttributes(attributes);
381   NamedAttributeList attrs(attributes);
382   if (!attrs.get(unorderedAttrName()))
383     result.addTypes(builder->getIndexType());
384 }
385 
/// Parse a LoopOp written as
///   %iv = %lb to %ub [<step-kw> %step] [<unordered-kw>] <region> <attr-dict>
/// where <step-kw> is LoopOp::stepAttrName() and <unordered-kw> is
/// LoopOp::unorderedAttrName(). The bounds and step are resolved as `index`.
/// Without the step keyword, no step operand is added; instead the step
/// attribute is set to an index-typed integer 1. The body region takes %iv as
/// its single `index` argument. One `index` result is added unless the
/// unordered keyword was present.
static mlir::ParseResult parseLoopOp(mlir::OpAsmParser &parser,
                                     mlir::OperationState &result) {
  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType inductionVariable, lb, ub, step;
  // Parse the induction variable followed by '='.
  if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
    return mlir::failure();

  // Parse loop bounds.
  mlir::Type indexType = builder.getIndexType();
  if (parser.parseOperand(lb) ||
      parser.resolveOperand(lb, indexType, result.operands) ||
      parser.parseKeyword("to") || parser.parseOperand(ub) ||
      parser.resolveOperand(ub, indexType, result.operands))
    return mlir::failure();

  // parseOptionalKeyword returns failure (true) when the keyword is absent:
  // in that case default the step to the constant 1 as an attribute,
  // otherwise parse the step operand.
  if (parser.parseOptionalKeyword(fir::LoopOp::stepAttrName())) {
    result.addAttribute(fir::LoopOp::stepAttrName(),
                        builder.getIntegerAttr(builder.getIndexType(), 1));
  } else if (parser.parseOperand(step) ||
             parser.resolveOperand(step, indexType, result.operands)) {
    return mlir::failure();
  }

  // Parse the optional `unordered` keyword
  bool isUnordered = false;
  if (!parser.parseOptionalKeyword(LoopOp::unorderedAttrName())) {
    result.addAttribute(LoopOp::unorderedAttrName(), builder.getUnitAttr());
    isUnordered = true;
  }

  // Parse the body region; the induction variable becomes its index-typed
  // block argument.
  mlir::Region *body = result.addRegion();
  if (parser.parseRegion(*body, inductionVariable, indexType))
    return mlir::failure();

  fir::LoopOp::ensureTerminator(*body, builder, result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return mlir::failure();
  // Only loops without the unordered keyword get an index result.
  if (!isUnordered)
    result.addTypes(builder.getIndexType());
  return mlir::success();
}
431 
432 fir::LoopOp fir::getForInductionVarOwner(mlir::Value val) {
433   auto ivArg = val.dyn_cast<mlir::BlockArgument>();
434   if (!ivArg)
435     return {};
436   assert(ivArg.getOwner() && "unlinked block argument");
437   auto *containingInst = ivArg.getOwner()->getParentOp();
438   return dyn_cast_or_null<fir::LoopOp>(containingInst);
439 }
440 
441 //===----------------------------------------------------------------------===//
442 // SelectOp
443 //===----------------------------------------------------------------------===//
444 
445 static constexpr llvm::StringRef getCompareOffsetAttr() {
446   return "compare_operand_offsets";
447 }
448 
449 static constexpr llvm::StringRef getTargetOffsetAttr() {
450   return "target_operand_offsets";
451 }
452 
453 template <typename A>
454 static A getSubOperands(unsigned pos, A allArgs,
455                         mlir::DenseIntElementsAttr ranges) {
456   unsigned start = 0;
457   for (unsigned i = 0; i < pos; ++i)
458     start += (*(ranges.begin() + i)).getZExtValue();
459   unsigned end = start + (*(ranges.begin() + pos)).getZExtValue();
460   return {std::next(allArgs.begin(), start), std::next(allArgs.begin(), end)};
461 }
462 
463 llvm::Optional<mlir::OperandRange> fir::SelectOp::getCompareOperands(unsigned) {
464   return {};
465 }
466 
467 llvm::Optional<llvm::ArrayRef<mlir::Value>>
468 fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
469   return {};
470 }
471 
472 llvm::Optional<mlir::OperandRange>
473 fir::SelectOp::getSuccessorOperands(unsigned oper) {
474   auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
475   return {getSubOperands(oper, targetArgs(), a)};
476 }
477 
478 llvm::Optional<llvm::ArrayRef<mlir::Value>>
479 fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
480                                     unsigned oper) {
481   auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
482   auto segments =
483       getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr());
484   return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
485 }
486 
487 bool fir::SelectOp::canEraseSuccessorOperand() { return true; }
488 
489 //===----------------------------------------------------------------------===//
490 // SelectCaseOp
491 //===----------------------------------------------------------------------===//
492 
493 llvm::Optional<mlir::OperandRange>
494 fir::SelectCaseOp::getCompareOperands(unsigned cond) {
495   auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getCompareOffsetAttr());
496   return {getSubOperands(cond, compareArgs(), a)};
497 }
498 
499 llvm::Optional<llvm::ArrayRef<mlir::Value>>
500 fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef<mlir::Value> operands,
501                                       unsigned cond) {
502   auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getCompareOffsetAttr());
503   auto segments =
504       getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr());
505   return {getSubOperands(cond, getSubOperands(1, operands, segments), a)};
506 }
507 
508 llvm::Optional<mlir::OperandRange>
509 fir::SelectCaseOp::getSuccessorOperands(unsigned oper) {
510   auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
511   return {getSubOperands(oper, targetArgs(), a)};
512 }
513 
514 llvm::Optional<llvm::ArrayRef<mlir::Value>>
515 fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
516                                         unsigned oper) {
517   auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
518   auto segments =
519       getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr());
520   return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
521 }
522 
523 bool fir::SelectCaseOp::canEraseSuccessorOperand() { return true; }
524 
525 // parser for fir.select_case Op
526 static mlir::ParseResult parseSelectCase(mlir::OpAsmParser &parser,
527                                          mlir::OperationState &result) {
528   mlir::OpAsmParser::OperandType selector;
529   mlir::Type type;
530   if (parseSelector(parser, result, selector, type))
531     return mlir::failure();
532 
533   llvm::SmallVector<mlir::Attribute, 8> attrs;
534   llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> opers;
535   llvm::SmallVector<mlir::Block *, 8> dests;
536   llvm::SmallVector<llvm::SmallVector<mlir::Value, 8>, 8> destArgs;
537   llvm::SmallVector<int32_t, 8> argOffs;
538   int32_t offSize = 0;
539   while (true) {
540     mlir::Attribute attr;
541     mlir::Block *dest;
542     llvm::SmallVector<mlir::Value, 8> destArg;
543     llvm::SmallVector<mlir::NamedAttribute, 1> temp;
544     if (parser.parseAttribute(attr, "a", temp) || isValidCaseAttr(attr) ||
545         parser.parseComma())
546       return mlir::failure();
547     attrs.push_back(attr);
548     if (attr.dyn_cast_or_null<mlir::UnitAttr>()) {
549       argOffs.push_back(0);
550     } else if (attr.dyn_cast_or_null<fir::ClosedIntervalAttr>()) {
551       mlir::OpAsmParser::OperandType oper1;
552       mlir::OpAsmParser::OperandType oper2;
553       if (parser.parseOperand(oper1) || parser.parseComma() ||
554           parser.parseOperand(oper2) || parser.parseComma())
555         return mlir::failure();
556       opers.push_back(oper1);
557       opers.push_back(oper2);
558       argOffs.push_back(2);
559       offSize += 2;
560     } else {
561       mlir::OpAsmParser::OperandType oper;
562       if (parser.parseOperand(oper) || parser.parseComma())
563         return mlir::failure();
564       opers.push_back(oper);
565       argOffs.push_back(1);
566       ++offSize;
567     }
568     if (parser.parseSuccessorAndUseList(dest, destArg))
569       return mlir::failure();
570     dests.push_back(dest);
571     destArgs.push_back(destArg);
572     if (!parser.parseOptionalRSquare())
573       break;
574     if (parser.parseComma())
575       return mlir::failure();
576   }
577   result.addAttribute(fir::SelectCaseOp::getCasesAttr(),
578                       parser.getBuilder().getArrayAttr(attrs));
579   if (parser.resolveOperands(opers, type, result.operands))
580     return mlir::failure();
581   llvm::SmallVector<int32_t, 8> targOffs;
582   int32_t toffSize = 0;
583   const auto count = dests.size();
584   for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) {
585     result.addSuccessors(dests[i]);
586     result.addOperands(destArgs[i]);
587     auto argSize = destArgs[i].size();
588     targOffs.push_back(argSize);
589     toffSize += argSize;
590   }
591   auto &bld = parser.getBuilder();
592   result.addAttribute(fir::SelectCaseOp::getOperandSegmentSizeAttr(),
593                       bld.getI32VectorAttr({1, offSize, toffSize}));
594   result.addAttribute(getCompareOffsetAttr(), bld.getI32VectorAttr(argOffs));
595   result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(targOffs));
596   return mlir::success();
597 }
598 
599 //===----------------------------------------------------------------------===//
600 // SelectRankOp
601 //===----------------------------------------------------------------------===//
602 
603 llvm::Optional<mlir::OperandRange>
604 fir::SelectRankOp::getCompareOperands(unsigned) {
605   return {};
606 }
607 
608 llvm::Optional<llvm::ArrayRef<mlir::Value>>
609 fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
610   return {};
611 }
612 
613 llvm::Optional<mlir::OperandRange>
614 fir::SelectRankOp::getSuccessorOperands(unsigned oper) {
615   auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
616   return {getSubOperands(oper, targetArgs(), a)};
617 }
618 
619 llvm::Optional<llvm::ArrayRef<mlir::Value>>
620 fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
621                                         unsigned oper) {
622   auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
623   auto segments =
624       getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr());
625   return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
626 }
627 
628 bool fir::SelectRankOp::canEraseSuccessorOperand() { return true; }
629 
630 //===----------------------------------------------------------------------===//
631 // SelectTypeOp
632 //===----------------------------------------------------------------------===//
633 
634 llvm::Optional<mlir::OperandRange>
635 fir::SelectTypeOp::getCompareOperands(unsigned) {
636   return {};
637 }
638 
639 llvm::Optional<llvm::ArrayRef<mlir::Value>>
640 fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
641   return {};
642 }
643 
644 llvm::Optional<mlir::OperandRange>
645 fir::SelectTypeOp::getSuccessorOperands(unsigned oper) {
646   auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
647   return {getSubOperands(oper, targetArgs(), a)};
648 }
649 
650 llvm::Optional<llvm::ArrayRef<mlir::Value>>
651 fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
652                                         unsigned oper) {
653   auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
654   auto segments =
655       getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr());
656   return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
657 }
658 
659 bool fir::SelectTypeOp::canEraseSuccessorOperand() { return true; }
660 
661 static ParseResult parseSelectType(OpAsmParser &parser,
662                                    OperationState &result) {
663   mlir::OpAsmParser::OperandType selector;
664   mlir::Type type;
665   if (parseSelector(parser, result, selector, type))
666     return mlir::failure();
667 
668   llvm::SmallVector<mlir::Attribute, 8> attrs;
669   llvm::SmallVector<mlir::Block *, 8> dests;
670   llvm::SmallVector<llvm::SmallVector<mlir::Value, 8>, 8> destArgs;
671   while (true) {
672     mlir::Attribute attr;
673     mlir::Block *dest;
674     llvm::SmallVector<mlir::Value, 8> destArg;
675     llvm::SmallVector<mlir::NamedAttribute, 1> temp;
676     if (parser.parseAttribute(attr, "a", temp) || parser.parseComma() ||
677         parser.parseSuccessorAndUseList(dest, destArg))
678       return mlir::failure();
679     attrs.push_back(attr);
680     dests.push_back(dest);
681     destArgs.push_back(destArg);
682     if (!parser.parseOptionalRSquare())
683       break;
684     if (parser.parseComma())
685       return mlir::failure();
686   }
687   auto &bld = parser.getBuilder();
688   result.addAttribute(fir::SelectTypeOp::getCasesAttr(),
689                       bld.getArrayAttr(attrs));
690   llvm::SmallVector<int32_t, 8> argOffs;
691   int32_t offSize = 0;
692   const auto count = dests.size();
693   for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) {
694     result.addSuccessors(dests[i]);
695     result.addOperands(destArgs[i]);
696     auto argSize = destArgs[i].size();
697     argOffs.push_back(argSize);
698     offSize += argSize;
699   }
700   result.addAttribute(fir::SelectTypeOp::getOperandSegmentSizeAttr(),
701                       bld.getI32VectorAttr({1, 0, offSize}));
702   result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(argOffs));
703   return mlir::success();
704 }
705 
706 //===----------------------------------------------------------------------===//
707 // StoreOp
708 //===----------------------------------------------------------------------===//
709 
710 mlir::Type fir::StoreOp::elementType(mlir::Type refType) {
711   if (auto ref = refType.dyn_cast<ReferenceType>())
712     return ref.getEleTy();
713   if (auto ref = refType.dyn_cast<PointerType>())
714     return ref.getEleTy();
715   if (auto ref = refType.dyn_cast<HeapType>())
716     return ref.getEleTy();
717   return {};
718 }
719 
720 //===----------------------------------------------------------------------===//
721 // StringLitOp
722 //===----------------------------------------------------------------------===//
723 
724 bool fir::StringLitOp::isWideValue() {
725   auto eleTy = getType().cast<fir::SequenceType>().getEleTy();
726   return eleTy.cast<fir::CharacterType>().getFKind() != 1;
727 }
728 
729 //===----------------------------------------------------------------------===//
730 // WhereOp
731 //===----------------------------------------------------------------------===//
732 
733 void fir::WhereOp::build(mlir::Builder *builder, OperationState &result,
734                          mlir::Value cond, bool withElseRegion) {
735   result.addOperands(cond);
736   mlir::Region *thenRegion = result.addRegion();
737   mlir::Region *elseRegion = result.addRegion();
738   WhereOp::ensureTerminator(*thenRegion, *builder, result.location);
739   if (withElseRegion)
740     WhereOp::ensureTerminator(*elseRegion, *builder, result.location);
741 }
742 
743 static mlir::ParseResult parseWhereOp(OpAsmParser &parser,
744                                       OperationState &result) {
745   result.regions.reserve(2);
746   mlir::Region *thenRegion = result.addRegion();
747   mlir::Region *elseRegion = result.addRegion();
748 
749   auto &builder = parser.getBuilder();
750   OpAsmParser::OperandType cond;
751   mlir::Type i1Type = builder.getIntegerType(1);
752   if (parser.parseOperand(cond) ||
753       parser.resolveOperand(cond, i1Type, result.operands))
754     return mlir::failure();
755 
756   if (parser.parseRegion(*thenRegion, {}, {}))
757     return mlir::failure();
758 
759   WhereOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
760 
761   if (!parser.parseOptionalKeyword("otherwise")) {
762     if (parser.parseRegion(*elseRegion, {}, {}))
763       return mlir::failure();
764     WhereOp::ensureTerminator(*elseRegion, parser.getBuilder(),
765                               result.location);
766   }
767 
768   // Parse the optional attribute list.
769   if (parser.parseOptionalAttrDict(result.attributes))
770     return mlir::failure();
771 
772   return mlir::success();
773 }
774 
775 //===----------------------------------------------------------------------===//
776 
777 mlir::ParseResult fir::isValidCaseAttr(mlir::Attribute attr) {
778   if (attr.dyn_cast_or_null<mlir::UnitAttr>() ||
779       attr.dyn_cast_or_null<ClosedIntervalAttr>() ||
780       attr.dyn_cast_or_null<PointIntervalAttr>() ||
781       attr.dyn_cast_or_null<LowerBoundAttr>() ||
782       attr.dyn_cast_or_null<UpperBoundAttr>())
783     return mlir::success();
784   return mlir::failure();
785 }
786 
787 unsigned fir::getCaseArgumentOffset(llvm::ArrayRef<mlir::Attribute> cases,
788                                     unsigned dest) {
789   unsigned o = 0;
790   for (unsigned i = 0; i < dest; ++i) {
791     auto &attr = cases[i];
792     if (!attr.dyn_cast_or_null<mlir::UnitAttr>()) {
793       ++o;
794       if (attr.dyn_cast_or_null<ClosedIntervalAttr>())
795         ++o;
796     }
797   }
798   return o;
799 }
800 
801 mlir::ParseResult fir::parseSelector(mlir::OpAsmParser &parser,
802                                      mlir::OperationState &result,
803                                      mlir::OpAsmParser::OperandType &selector,
804                                      mlir::Type &type) {
805   if (parser.parseOperand(selector) || parser.parseColonType(type) ||
806       parser.resolveOperand(selector, type, result.operands) ||
807       parser.parseLSquare())
808     return mlir::failure();
809   return mlir::success();
810 }
811 
812 /// Generic pretty-printer of a binary operation
813 static void printBinaryOp(Operation *op, OpAsmPrinter &p) {
814   assert(op->getNumOperands() == 2 && "binary op must have two operands");
815   assert(op->getNumResults() == 1 && "binary op must have one result");
816 
817   p << op->getName() << ' ' << op->getOperand(0) << ", " << op->getOperand(1);
818   p.printOptionalAttrDict(op->getAttrs());
819   p << " : " << op->getResult(0).getType();
820 }
821 
822 /// Generic pretty-printer of an unary operation
823 static void printUnaryOp(Operation *op, OpAsmPrinter &p) {
824   assert(op->getNumOperands() == 1 && "unary op must have one operand");
825   assert(op->getNumResults() == 1 && "unary op must have one result");
826 
827   p << op->getName() << ' ' << op->getOperand(0);
828   p.printOptionalAttrDict(op->getAttrs());
829   p << " : " << op->getResult(0).getType();
830 }
831 
832 bool fir::isReferenceLike(mlir::Type type) {
833   return type.isa<fir::ReferenceType>() || type.isa<fir::HeapType>() ||
834          type.isa<fir::PointerType>();
835 }
836 
837 mlir::FuncOp fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module,
838                                StringRef name, mlir::FunctionType type,
839                                llvm::ArrayRef<mlir::NamedAttribute> attrs) {
840   if (auto f = module.lookupSymbol<mlir::FuncOp>(name))
841     return f;
842   mlir::OpBuilder modBuilder(module.getBodyRegion());
843   return modBuilder.create<mlir::FuncOp>(loc, name, type, attrs);
844 }
845 
846 fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module,
847                                   StringRef name, mlir::Type type,
848                                   llvm::ArrayRef<mlir::NamedAttribute> attrs) {
849   if (auto g = module.lookupSymbol<fir::GlobalOp>(name))
850     return g;
851   mlir::OpBuilder modBuilder(module.getBodyRegion());
852   return modBuilder.create<fir::GlobalOp>(loc, name, type, attrs);
853 }
854 
namespace fir {

// Tablegen operators
//
// Pull in the op class definitions generated by tablegen from the FIR op
// descriptions. Defining GET_OP_CLASSES selects the op class definitions
// section of the generated .cpp.inc file. The include sits inside
// `namespace fir` so the generated classes are defined in that namespace.
// NOTE(review): the generated code presumably refers to the file-local
// parse/print helpers defined above (e.g. parseWhereOp, printBinaryOp), so
// keep this include after them; confirm against FIROps.td.

#define GET_OP_CLASSES
#include "flang/Optimizer/Dialect/FIROps.cpp.inc"

} // namespace fir
863