1 //===-- FIROps.cpp --------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Dialect/FIROps.h" 10 #include "flang/Optimizer/Dialect/FIRAttr.h" 11 #include "flang/Optimizer/Dialect/FIROpsSupport.h" 12 #include "flang/Optimizer/Dialect/FIRType.h" 13 #include "mlir/ADT/TypeSwitch.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/Diagnostics.h" 16 #include "mlir/IR/Function.h" 17 #include "mlir/IR/Module.h" 18 #include "mlir/IR/StandardTypes.h" 19 #include "mlir/IR/SymbolTable.h" 20 #include "llvm/ADT/StringSwitch.h" 21 22 using namespace fir; 23 24 /// return true if the sequence type is abstract or the record type is malformed 25 /// or contains an abstract sequence type 26 static bool verifyInType(mlir::Type inType, 27 llvm::SmallVectorImpl<llvm::StringRef> &visited) { 28 if (auto st = inType.dyn_cast<fir::SequenceType>()) { 29 auto shape = st.getShape(); 30 if (shape.size() == 0) 31 return true; 32 for (auto ext : shape) 33 if (ext < 0) 34 return true; 35 } else if (auto rt = inType.dyn_cast<fir::RecordType>()) { 36 // don't recurse if we're already visiting this one 37 if (llvm::is_contained(visited, rt.getName())) 38 return false; 39 // keep track of record types currently being visited 40 visited.push_back(rt.getName()); 41 for (auto &field : rt.getTypeList()) 42 if (verifyInType(field.second, visited)) 43 return true; 44 visited.pop_back(); 45 } else if (auto rt = inType.dyn_cast<fir::PointerType>()) { 46 return verifyInType(rt.getEleTy(), visited); 47 } 48 return false; 49 } 50 51 static bool verifyRecordLenParams(mlir::Type inType, unsigned numLenParams) { 52 if (numLenParams > 0) { 53 if (auto rt = inType.dyn_cast<fir::RecordType>()) 54 return numLenParams != rt.getNumLenParams(); 55 return true; 56 } 57 return false; 58 } 59 60 //===----------------------------------------------------------------------===// 61 // AllocaOp 62 //===----------------------------------------------------------------------===// 63 64 mlir::Type fir::AllocaOp::getAllocatedType() { 65 return getType().cast<ReferenceType>().getEleTy(); 66 } 67 68 /// Create a legal memory reference as return type 69 mlir::Type fir::AllocaOp::wrapResultType(mlir::Type intype) { 70 // FIR semantics: memory references to memory references are disallowed 71 if (intype.isa<ReferenceType>()) 72 return {}; 73 return ReferenceType::get(intype); 74 } 75 76 mlir::Type fir::AllocaOp::getRefTy(mlir::Type ty) { 77 return ReferenceType::get(ty); 78 } 79 80 //===----------------------------------------------------------------------===// 81 // AllocMemOp 82 //===----------------------------------------------------------------------===// 83 84 mlir::Type fir::AllocMemOp::getAllocatedType() { 85 return getType().cast<HeapType>().getEleTy(); 86 } 87 88 mlir::Type fir::AllocMemOp::getRefTy(mlir::Type ty) { 89 return HeapType::get(ty); 90 } 91 92 /// Create a legal heap reference as return type 93 mlir::Type fir::AllocMemOp::wrapResultType(mlir::Type intype) { 94 // Fortran semantics: C852 an entity cannot be both ALLOCATABLE and POINTER 95 // 8.5.3 note 1 prohibits ALLOCATABLE procedures as well 96 // FIR semantics: one may not allocate a memory reference value 97 if (intype.isa<ReferenceType>() || intype.isa<HeapType>() || 98 intype.isa<PointerType>() || intype.isa<FunctionType>()) 99 return {}; 100 return HeapType::get(intype); 101 } 102 103 //===----------------------------------------------------------------------===// 104 // BoxDimsOp 105 //===----------------------------------------------------------------------===// 106 107 /// Get the result types packed in a tuple tuple 108 mlir::Type fir::BoxDimsOp::getTupleType() { 109 // note: triple, but 4 is nearest power of 2 110 llvm::SmallVector<mlir::Type, 4> triple{ 111 getResult(0).getType(), getResult(1).getType(), getResult(2).getType()}; 112 return mlir::TupleType::get(triple, getContext()); 113 } 114 115 //===----------------------------------------------------------------------===// 116 // CallOp 117 //===----------------------------------------------------------------------===// 118 119 static void printCallOp(mlir::OpAsmPrinter &p, fir::CallOp &op) { 120 auto callee = op.callee(); 121 bool isDirect = callee.hasValue(); 122 p << op.getOperationName() << ' '; 123 if (isDirect) 124 p << callee.getValue(); 125 else 126 p << op.getOperand(0); 127 p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; 128 p.printOptionalAttrDict(op.getAttrs(), {fir::CallOp::calleeAttrName()}); 129 auto resultTypes{op.getResultTypes()}; 130 llvm::SmallVector<Type, 8> argTypes( 131 llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1)); 132 p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext()); 133 } 134 135 static mlir::ParseResult parseCallOp(mlir::OpAsmParser &parser, 136 mlir::OperationState &result) { 137 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> operands; 138 if (parser.parseOperandList(operands)) 139 return mlir::failure(); 140 141 llvm::SmallVector<mlir::NamedAttribute, 4> attrs; 142 mlir::SymbolRefAttr funcAttr; 143 bool isDirect = operands.empty(); 144 if (isDirect) 145 if (parser.parseAttribute(funcAttr, fir::CallOp::calleeAttrName(), attrs)) 146 return mlir::failure(); 147 148 Type type; 149 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) || 150 parser.parseOptionalAttrDict(attrs) || parser.parseColon() || 151 parser.parseType(type)) 152 return mlir::failure(); 153 154 auto funcType = type.dyn_cast<mlir::FunctionType>(); 155 if (!funcType) 156 return parser.emitError(parser.getNameLoc(), "expected function type"); 157 if (isDirect) { 158 if (parser.resolveOperands(operands, funcType.getInputs(), 159 parser.getNameLoc(), result.operands)) 160 return mlir::failure(); 161 } else { 162 auto funcArgs = 163 llvm::ArrayRef<mlir::OpAsmParser::OperandType>(operands).drop_front(); 164 llvm::SmallVector<mlir::Value, 8> resultArgs( 165 result.operands.begin() + (result.operands.empty() ? 0 : 1), 166 result.operands.end()); 167 if (parser.resolveOperand(operands[0], funcType, result.operands) || 168 parser.resolveOperands(funcArgs, funcType.getInputs(), 169 parser.getNameLoc(), resultArgs)) 170 return mlir::failure(); 171 } 172 result.addTypes(funcType.getResults()); 173 result.attributes = attrs; 174 return mlir::success(); 175 } 176 177 //===----------------------------------------------------------------------===// 178 // CmpfOp 179 //===----------------------------------------------------------------------===// 180 181 // Note: getCmpFPredicateNames() is inline static in StandardOps/IR/Ops.cpp 182 mlir::CmpFPredicate fir::CmpfOp::getPredicateByName(llvm::StringRef name) { 183 auto pred = mlir::symbolizeCmpFPredicate(name); 184 assert(pred.hasValue() && "invalid predicate name"); 185 return pred.getValue(); 186 } 187 188 void fir::buildCmpFOp(Builder *builder, OperationState &result, 189 CmpFPredicate predicate, Value lhs, Value rhs) { 190 result.addOperands({lhs, rhs}); 191 result.types.push_back(builder->getI1Type()); 192 result.addAttribute( 193 CmpfOp::getPredicateAttrName(), 194 builder->getI64IntegerAttr(static_cast<int64_t>(predicate))); 195 } 196 197 template <typename OPTY> 198 static void printCmpOp(OpAsmPrinter &p, OPTY op) { 199 p << op.getOperationName() << ' '; 200 auto predSym = mlir::symbolizeCmpFPredicate( 201 op.template getAttrOfType<mlir::IntegerAttr>(OPTY::getPredicateAttrName()) 202 .getInt()); 203 assert(predSym.hasValue() && "invalid symbol value for predicate"); 204 p << '"' << mlir::stringifyCmpFPredicate(predSym.getValue()) << '"' << ", "; 205 p.printOperand(op.lhs()); 206 p << ", "; 207 p.printOperand(op.rhs()); 208 p.printOptionalAttrDict(op.getAttrs(), 209 /*elidedAttrs=*/{OPTY::getPredicateAttrName()}); 210 p << " : " << op.lhs().getType(); 211 } 212 213 static void printCmpfOp(OpAsmPrinter &p, CmpfOp op) { printCmpOp(p, op); } 214 215 template <typename OPTY> 216 static mlir::ParseResult parseCmpOp(mlir::OpAsmParser &parser, 217 mlir::OperationState &result) { 218 llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> ops; 219 llvm::SmallVector<mlir::NamedAttribute, 4> attrs; 220 mlir::Attribute predicateNameAttr; 221 mlir::Type type; 222 if (parser.parseAttribute(predicateNameAttr, OPTY::getPredicateAttrName(), 223 attrs) || 224 parser.parseComma() || parser.parseOperandList(ops, 2) || 225 parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) || 226 parser.resolveOperands(ops, type, result.operands)) 227 return failure(); 228 229 if (!predicateNameAttr.isa<mlir::StringAttr>()) 230 return parser.emitError(parser.getNameLoc(), 231 "expected string comparison predicate attribute"); 232 233 // Rewrite string attribute to an enum value. 234 llvm::StringRef predicateName = 235 predicateNameAttr.cast<mlir::StringAttr>().getValue(); 236 auto predicate = fir::CmpfOp::getPredicateByName(predicateName); 237 auto builder = parser.getBuilder(); 238 mlir::Type i1Type = builder.getI1Type(); 239 attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate)); 240 result.attributes = attrs; 241 result.addTypes({i1Type}); 242 return success(); 243 } 244 245 mlir::ParseResult fir::parseCmpfOp(mlir::OpAsmParser &parser, 246 mlir::OperationState &result) { 247 return parseCmpOp<fir::CmpfOp>(parser, result); 248 } 249 250 //===----------------------------------------------------------------------===// 251 // CmpcOp 252 //===----------------------------------------------------------------------===// 253 254 void fir::buildCmpCOp(Builder *builder, OperationState &result, 255 CmpFPredicate predicate, Value lhs, Value rhs) { 256 result.addOperands({lhs, rhs}); 257 result.types.push_back(builder->getI1Type()); 258 result.addAttribute( 259 fir::CmpcOp::getPredicateAttrName(), 260 builder->getI64IntegerAttr(static_cast<int64_t>(predicate))); 261 } 262 263 static void printCmpcOp(OpAsmPrinter &p, fir::CmpcOp op) { printCmpOp(p, op); } 264 265 mlir::ParseResult fir::parseCmpcOp(mlir::OpAsmParser &parser, 266 mlir::OperationState &result) { 267 return parseCmpOp<fir::CmpcOp>(parser, result); 268 } 269 270 //===----------------------------------------------------------------------===// 271 // DispatchOp 272 //===----------------------------------------------------------------------===// 273 274 mlir::FunctionType fir::DispatchOp::getFunctionType() { 275 auto attr = getAttr("fn_type").cast<mlir::TypeAttr>(); 276 return attr.getValue().cast<mlir::FunctionType>(); 277 } 278 279 //===----------------------------------------------------------------------===// 280 // DispatchTableOp 281 //===----------------------------------------------------------------------===// 282 283 void fir::DispatchTableOp::appendTableEntry(mlir::Operation *op) { 284 assert(mlir::isa<fir::DTEntryOp>(*op) && "operation must be a DTEntryOp"); 285 auto &block = getBlock(); 286 block.getOperations().insert(block.end(), op); 287 } 288 289 //===----------------------------------------------------------------------===// 290 // EmboxOp 291 //===----------------------------------------------------------------------===// 292 293 static mlir::ParseResult parseEmboxOp(mlir::OpAsmParser &parser, 294 mlir::OperationState &result) { 295 mlir::FunctionType type; 296 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> operands; 297 mlir::OpAsmParser::OperandType memref; 298 if (parser.parseOperand(memref)) 299 return mlir::failure(); 300 operands.push_back(memref); 301 auto &builder = parser.getBuilder(); 302 if (!parser.parseOptionalLParen()) { 303 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || 304 parser.parseRParen()) 305 return mlir::failure(); 306 auto lens = builder.getI32IntegerAttr(operands.size()); 307 result.addAttribute(fir::EmboxOp::lenpName(), lens); 308 } 309 if (!parser.parseOptionalComma()) { 310 mlir::OpAsmParser::OperandType dims; 311 if (parser.parseOperand(dims)) 312 return mlir::failure(); 313 operands.push_back(dims); 314 } else if (!parser.parseOptionalLSquare()) { 315 mlir::AffineMapAttr map; 316 if (parser.parseAttribute(map, fir::EmboxOp::layoutName(), 317 result.attributes) || 318 parser.parseRSquare()) 319 return mlir::failure(); 320 } 321 if (parser.parseOptionalAttrDict(result.attributes) || 322 parser.parseColonType(type) || 323 parser.resolveOperands(operands, type.getInputs(), parser.getNameLoc(), 324 result.operands) || 325 parser.addTypesToList(type.getResults(), result.types)) 326 return mlir::failure(); 327 return mlir::success(); 328 } 329 330 //===----------------------------------------------------------------------===// 331 // GenTypeDescOp 332 //===----------------------------------------------------------------------===// 333 334 void fir::GenTypeDescOp::build(Builder *, OperationState &result, 335 mlir::TypeAttr inty) { 336 result.addAttribute("in_type", inty); 337 result.addTypes(TypeDescType::get(inty.getValue())); 338 } 339 340 //===----------------------------------------------------------------------===// 341 // GlobalOp 342 //===----------------------------------------------------------------------===// 343 344 void fir::GlobalOp::appendInitialValue(mlir::Operation *op) { 345 getBlock().getOperations().push_back(op); 346 } 347 348 //===----------------------------------------------------------------------===// 349 // LoadOp 350 //===----------------------------------------------------------------------===// 351 352 /// Get the element type of a reference like type; otherwise null 353 static mlir::Type elementTypeOf(mlir::Type ref) { 354 return mlir::TypeSwitch<mlir::Type, mlir::Type>(ref) 355 .Case<ReferenceType, PointerType, HeapType>( 356 [](auto type) { return type.getEleTy(); }) 357 .Default([](mlir::Type) { return mlir::Type{}; }); 358 } 359 360 mlir::ParseResult fir::LoadOp::getElementOf(mlir::Type &ele, mlir::Type ref) { 361 if ((ele = elementTypeOf(ref))) 362 return mlir::success(); 363 return mlir::failure(); 364 } 365 366 //===----------------------------------------------------------------------===// 367 // LoopOp 368 //===----------------------------------------------------------------------===// 369 370 void fir::LoopOp::build(mlir::Builder *builder, OperationState &result, 371 mlir::Value lb, mlir::Value ub, ValueRange step, 372 ArrayRef<NamedAttribute> attributes) { 373 if (step.empty()) 374 result.addOperands({lb, ub}); 375 else 376 result.addOperands({lb, ub, step[0]}); 377 mlir::Region *bodyRegion = result.addRegion(); 378 LoopOp::ensureTerminator(*bodyRegion, *builder, result.location); 379 bodyRegion->front().addArgument(builder->getIndexType()); 380 result.addAttributes(attributes); 381 NamedAttributeList attrs(attributes); 382 if (!attrs.get(unorderedAttrName())) 383 result.addTypes(builder->getIndexType()); 384 } 385 386 static mlir::ParseResult parseLoopOp(mlir::OpAsmParser &parser, 387 mlir::OperationState &result) { 388 auto &builder = parser.getBuilder(); 389 OpAsmParser::OperandType inductionVariable, lb, ub, step; 390 // Parse the induction variable followed by '='. 391 if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual()) 392 return mlir::failure(); 393 394 // Parse loop bounds. 395 mlir::Type indexType = builder.getIndexType(); 396 if (parser.parseOperand(lb) || 397 parser.resolveOperand(lb, indexType, result.operands) || 398 parser.parseKeyword("to") || parser.parseOperand(ub) || 399 parser.resolveOperand(ub, indexType, result.operands)) 400 return mlir::failure(); 401 402 if (parser.parseOptionalKeyword(fir::LoopOp::stepAttrName())) { 403 result.addAttribute(fir::LoopOp::stepAttrName(), 404 builder.getIntegerAttr(builder.getIndexType(), 1)); 405 } else if (parser.parseOperand(step) || 406 parser.resolveOperand(step, indexType, result.operands)) { 407 return mlir::failure(); 408 } 409 410 // Parse the optional `unordered` keyword 411 bool isUnordered = false; 412 if (!parser.parseOptionalKeyword(LoopOp::unorderedAttrName())) { 413 result.addAttribute(LoopOp::unorderedAttrName(), builder.getUnitAttr()); 414 isUnordered = true; 415 } 416 417 // Parse the body region. 418 mlir::Region *body = result.addRegion(); 419 if (parser.parseRegion(*body, inductionVariable, indexType)) 420 return mlir::failure(); 421 422 fir::LoopOp::ensureTerminator(*body, builder, result.location); 423 424 // Parse the optional attribute list. 425 if (parser.parseOptionalAttrDict(result.attributes)) 426 return mlir::failure(); 427 if (!isUnordered) 428 result.addTypes(builder.getIndexType()); 429 return mlir::success(); 430 } 431 432 fir::LoopOp fir::getForInductionVarOwner(mlir::Value val) { 433 auto ivArg = val.dyn_cast<mlir::BlockArgument>(); 434 if (!ivArg) 435 return {}; 436 assert(ivArg.getOwner() && "unlinked block argument"); 437 auto *containingInst = ivArg.getOwner()->getParentOp(); 438 return dyn_cast_or_null<fir::LoopOp>(containingInst); 439 } 440 441 //===----------------------------------------------------------------------===// 442 // SelectOp 443 //===----------------------------------------------------------------------===// 444 445 static constexpr llvm::StringRef getCompareOffsetAttr() { 446 return "compare_operand_offsets"; 447 } 448 449 static constexpr llvm::StringRef getTargetOffsetAttr() { 450 return "target_operand_offsets"; 451 } 452 453 template <typename A> 454 static A getSubOperands(unsigned pos, A allArgs, 455 mlir::DenseIntElementsAttr ranges) { 456 unsigned start = 0; 457 for (unsigned i = 0; i < pos; ++i) 458 start += (*(ranges.begin() + i)).getZExtValue(); 459 unsigned end = start + (*(ranges.begin() + pos)).getZExtValue(); 460 return {std::next(allArgs.begin(), start), std::next(allArgs.begin(), end)}; 461 } 462 463 llvm::Optional<mlir::OperandRange> fir::SelectOp::getCompareOperands(unsigned) { 464 return {}; 465 } 466 467 llvm::Optional<llvm::ArrayRef<mlir::Value>> 468 fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 469 return {}; 470 } 471 472 llvm::Optional<mlir::OperandRange> 473 fir::SelectOp::getSuccessorOperands(unsigned oper) { 474 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 475 return {getSubOperands(oper, targetArgs(), a)}; 476 } 477 478 llvm::Optional<llvm::ArrayRef<mlir::Value>> 479 fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 480 unsigned oper) { 481 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 482 auto segments = 483 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 484 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 485 } 486 487 bool fir::SelectOp::canEraseSuccessorOperand() { return true; } 488 489 //===----------------------------------------------------------------------===// 490 // SelectCaseOp 491 //===----------------------------------------------------------------------===// 492 493 llvm::Optional<mlir::OperandRange> 494 fir::SelectCaseOp::getCompareOperands(unsigned cond) { 495 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getCompareOffsetAttr()); 496 return {getSubOperands(cond, compareArgs(), a)}; 497 } 498 499 llvm::Optional<llvm::ArrayRef<mlir::Value>> 500 fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef<mlir::Value> operands, 501 unsigned cond) { 502 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getCompareOffsetAttr()); 503 auto segments = 504 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 505 return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; 506 } 507 508 llvm::Optional<mlir::OperandRange> 509 fir::SelectCaseOp::getSuccessorOperands(unsigned oper) { 510 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 511 return {getSubOperands(oper, targetArgs(), a)}; 512 } 513 514 llvm::Optional<llvm::ArrayRef<mlir::Value>> 515 fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 516 unsigned oper) { 517 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 518 auto segments = 519 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 520 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 521 } 522 523 bool fir::SelectCaseOp::canEraseSuccessorOperand() { return true; } 524 525 // parser for fir.select_case Op 526 static mlir::ParseResult parseSelectCase(mlir::OpAsmParser &parser, 527 mlir::OperationState &result) { 528 mlir::OpAsmParser::OperandType selector; 529 mlir::Type type; 530 if (parseSelector(parser, result, selector, type)) 531 return mlir::failure(); 532 533 llvm::SmallVector<mlir::Attribute, 8> attrs; 534 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> opers; 535 llvm::SmallVector<mlir::Block *, 8> dests; 536 llvm::SmallVector<llvm::SmallVector<mlir::Value, 8>, 8> destArgs; 537 llvm::SmallVector<int32_t, 8> argOffs; 538 int32_t offSize = 0; 539 while (true) { 540 mlir::Attribute attr; 541 mlir::Block *dest; 542 llvm::SmallVector<mlir::Value, 8> destArg; 543 llvm::SmallVector<mlir::NamedAttribute, 1> temp; 544 if (parser.parseAttribute(attr, "a", temp) || isValidCaseAttr(attr) || 545 parser.parseComma()) 546 return mlir::failure(); 547 attrs.push_back(attr); 548 if (attr.dyn_cast_or_null<mlir::UnitAttr>()) { 549 argOffs.push_back(0); 550 } else if (attr.dyn_cast_or_null<fir::ClosedIntervalAttr>()) { 551 mlir::OpAsmParser::OperandType oper1; 552 mlir::OpAsmParser::OperandType oper2; 553 if (parser.parseOperand(oper1) || parser.parseComma() || 554 parser.parseOperand(oper2) || parser.parseComma()) 555 return mlir::failure(); 556 opers.push_back(oper1); 557 opers.push_back(oper2); 558 argOffs.push_back(2); 559 offSize += 2; 560 } else { 561 mlir::OpAsmParser::OperandType oper; 562 if (parser.parseOperand(oper) || parser.parseComma()) 563 return mlir::failure(); 564 opers.push_back(oper); 565 argOffs.push_back(1); 566 ++offSize; 567 } 568 if (parser.parseSuccessorAndUseList(dest, destArg)) 569 return mlir::failure(); 570 dests.push_back(dest); 571 destArgs.push_back(destArg); 572 if (!parser.parseOptionalRSquare()) 573 break; 574 if (parser.parseComma()) 575 return mlir::failure(); 576 } 577 result.addAttribute(fir::SelectCaseOp::getCasesAttr(), 578 parser.getBuilder().getArrayAttr(attrs)); 579 if (parser.resolveOperands(opers, type, result.operands)) 580 return mlir::failure(); 581 llvm::SmallVector<int32_t, 8> targOffs; 582 int32_t toffSize = 0; 583 const auto count = dests.size(); 584 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 585 result.addSuccessors(dests[i]); 586 result.addOperands(destArgs[i]); 587 auto argSize = destArgs[i].size(); 588 targOffs.push_back(argSize); 589 toffSize += argSize; 590 } 591 auto &bld = parser.getBuilder(); 592 result.addAttribute(fir::SelectCaseOp::getOperandSegmentSizeAttr(), 593 bld.getI32VectorAttr({1, offSize, toffSize})); 594 result.addAttribute(getCompareOffsetAttr(), bld.getI32VectorAttr(argOffs)); 595 result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(targOffs)); 596 return mlir::success(); 597 } 598 599 //===----------------------------------------------------------------------===// 600 // SelectRankOp 601 //===----------------------------------------------------------------------===// 602 603 llvm::Optional<mlir::OperandRange> 604 fir::SelectRankOp::getCompareOperands(unsigned) { 605 return {}; 606 } 607 608 llvm::Optional<llvm::ArrayRef<mlir::Value>> 609 fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 610 return {}; 611 } 612 613 llvm::Optional<mlir::OperandRange> 614 fir::SelectRankOp::getSuccessorOperands(unsigned oper) { 615 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 616 return {getSubOperands(oper, targetArgs(), a)}; 617 } 618 619 llvm::Optional<llvm::ArrayRef<mlir::Value>> 620 fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 621 unsigned oper) { 622 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 623 auto segments = 624 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 625 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 626 } 627 628 bool fir::SelectRankOp::canEraseSuccessorOperand() { return true; } 629 630 //===----------------------------------------------------------------------===// 631 // SelectTypeOp 632 //===----------------------------------------------------------------------===// 633 634 llvm::Optional<mlir::OperandRange> 635 fir::SelectTypeOp::getCompareOperands(unsigned) { 636 return {}; 637 } 638 639 llvm::Optional<llvm::ArrayRef<mlir::Value>> 640 fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 641 return {}; 642 } 643 644 llvm::Optional<mlir::OperandRange> 645 fir::SelectTypeOp::getSuccessorOperands(unsigned oper) { 646 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 647 return {getSubOperands(oper, targetArgs(), a)}; 648 } 649 650 llvm::Optional<llvm::ArrayRef<mlir::Value>> 651 fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 652 unsigned oper) { 653 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 654 auto segments = 655 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 656 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 657 } 658 659 bool fir::SelectTypeOp::canEraseSuccessorOperand() { return true; } 660 661 static ParseResult parseSelectType(OpAsmParser &parser, 662 OperationState &result) { 663 mlir::OpAsmParser::OperandType selector; 664 mlir::Type type; 665 if (parseSelector(parser, result, selector, type)) 666 return mlir::failure(); 667 668 llvm::SmallVector<mlir::Attribute, 8> attrs; 669 llvm::SmallVector<mlir::Block *, 8> dests; 670 llvm::SmallVector<llvm::SmallVector<mlir::Value, 8>, 8> destArgs; 671 while (true) { 672 mlir::Attribute attr; 673 mlir::Block *dest; 674 llvm::SmallVector<mlir::Value, 8> destArg; 675 llvm::SmallVector<mlir::NamedAttribute, 1> temp; 676 if (parser.parseAttribute(attr, "a", temp) || parser.parseComma() || 677 parser.parseSuccessorAndUseList(dest, destArg)) 678 return mlir::failure(); 679 attrs.push_back(attr); 680 dests.push_back(dest); 681 destArgs.push_back(destArg); 682 if (!parser.parseOptionalRSquare()) 683 break; 684 if (parser.parseComma()) 685 return mlir::failure(); 686 } 687 auto &bld = parser.getBuilder(); 688 result.addAttribute(fir::SelectTypeOp::getCasesAttr(), 689 bld.getArrayAttr(attrs)); 690 llvm::SmallVector<int32_t, 8> argOffs; 691 int32_t offSize = 0; 692 const auto count = dests.size(); 693 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 694 result.addSuccessors(dests[i]); 695 result.addOperands(destArgs[i]); 696 auto argSize = destArgs[i].size(); 697 argOffs.push_back(argSize); 698 offSize += argSize; 699 } 700 result.addAttribute(fir::SelectTypeOp::getOperandSegmentSizeAttr(), 701 bld.getI32VectorAttr({1, 0, offSize})); 702 result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(argOffs)); 703 return mlir::success(); 704 } 705 706 //===----------------------------------------------------------------------===// 707 // StoreOp 708 //===----------------------------------------------------------------------===// 709 710 mlir::Type fir::StoreOp::elementType(mlir::Type refType) { 711 if (auto ref = refType.dyn_cast<ReferenceType>()) 712 return ref.getEleTy(); 713 if (auto ref = refType.dyn_cast<PointerType>()) 714 return ref.getEleTy(); 715 if (auto ref = refType.dyn_cast<HeapType>()) 716 return ref.getEleTy(); 717 return {}; 718 } 719 720 //===----------------------------------------------------------------------===// 721 // StringLitOp 722 //===----------------------------------------------------------------------===// 723 724 bool fir::StringLitOp::isWideValue() { 725 auto eleTy = getType().cast<fir::SequenceType>().getEleTy(); 726 return eleTy.cast<fir::CharacterType>().getFKind() != 1; 727 } 728 729 //===----------------------------------------------------------------------===// 730 // WhereOp 731 //===----------------------------------------------------------------------===// 732 733 void fir::WhereOp::build(mlir::Builder *builder, OperationState &result, 734 mlir::Value cond, bool withElseRegion) { 735 result.addOperands(cond); 736 mlir::Region *thenRegion = result.addRegion(); 737 mlir::Region *elseRegion = result.addRegion(); 738 WhereOp::ensureTerminator(*thenRegion, *builder, result.location); 739 if (withElseRegion) 740 WhereOp::ensureTerminator(*elseRegion, *builder, result.location); 741 } 742 743 static mlir::ParseResult parseWhereOp(OpAsmParser &parser, 744 OperationState &result) { 745 result.regions.reserve(2); 746 mlir::Region *thenRegion = result.addRegion(); 747 mlir::Region *elseRegion = result.addRegion(); 748 749 auto &builder = parser.getBuilder(); 750 OpAsmParser::OperandType cond; 751 mlir::Type i1Type = builder.getIntegerType(1); 752 if (parser.parseOperand(cond) || 753 parser.resolveOperand(cond, i1Type, result.operands)) 754 return mlir::failure(); 755 756 if (parser.parseRegion(*thenRegion, {}, {})) 757 return mlir::failure(); 758 759 WhereOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location); 760 761 if (!parser.parseOptionalKeyword("otherwise")) { 762 if (parser.parseRegion(*elseRegion, {}, {})) 763 return mlir::failure(); 764 WhereOp::ensureTerminator(*elseRegion, parser.getBuilder(), 765 result.location); 766 } 767 768 // Parse the optional attribute list. 769 if (parser.parseOptionalAttrDict(result.attributes)) 770 return mlir::failure(); 771 772 return mlir::success(); 773 } 774 775 //===----------------------------------------------------------------------===// 776 777 mlir::ParseResult fir::isValidCaseAttr(mlir::Attribute attr) { 778 if (attr.dyn_cast_or_null<mlir::UnitAttr>() || 779 attr.dyn_cast_or_null<ClosedIntervalAttr>() || 780 attr.dyn_cast_or_null<PointIntervalAttr>() || 781 attr.dyn_cast_or_null<LowerBoundAttr>() || 782 attr.dyn_cast_or_null<UpperBoundAttr>()) 783 return mlir::success(); 784 return mlir::failure(); 785 } 786 787 unsigned fir::getCaseArgumentOffset(llvm::ArrayRef<mlir::Attribute> cases, 788 unsigned dest) { 789 unsigned o = 0; 790 for (unsigned i = 0; i < dest; ++i) { 791 auto &attr = cases[i]; 792 if (!attr.dyn_cast_or_null<mlir::UnitAttr>()) { 793 ++o; 794 if (attr.dyn_cast_or_null<ClosedIntervalAttr>()) 795 ++o; 796 } 797 } 798 return o; 799 } 800 801 mlir::ParseResult fir::parseSelector(mlir::OpAsmParser &parser, 802 mlir::OperationState &result, 803 mlir::OpAsmParser::OperandType &selector, 804 mlir::Type &type) { 805 if (parser.parseOperand(selector) || parser.parseColonType(type) || 806 parser.resolveOperand(selector, type, result.operands) || 807 parser.parseLSquare()) 808 return mlir::failure(); 809 return mlir::success(); 810 } 811 812 /// Generic pretty-printer of a binary operation 813 static void printBinaryOp(Operation *op, OpAsmPrinter &p) { 814 assert(op->getNumOperands() == 2 && "binary op must have two operands"); 815 assert(op->getNumResults() == 1 && "binary op must have one result"); 816 817 p << op->getName() << ' ' << op->getOperand(0) << ", " << op->getOperand(1); 818 p.printOptionalAttrDict(op->getAttrs()); 819 p << " : " << op->getResult(0).getType(); 820 } 821 822 /// Generic pretty-printer of an unary operation 823 static void printUnaryOp(Operation *op, OpAsmPrinter &p) { 824 assert(op->getNumOperands() == 1 && "unary op must have one operand"); 825 assert(op->getNumResults() == 1 && "unary op must have one result"); 826 827 p << op->getName() << ' ' << op->getOperand(0); 828 p.printOptionalAttrDict(op->getAttrs()); 829 p << " : " << op->getResult(0).getType(); 830 } 831 832 bool fir::isReferenceLike(mlir::Type type) { 833 return type.isa<fir::ReferenceType>() || type.isa<fir::HeapType>() || 834 type.isa<fir::PointerType>(); 835 } 836 837 mlir::FuncOp fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module, 838 StringRef name, mlir::FunctionType type, 839 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 840 if (auto f = module.lookupSymbol<mlir::FuncOp>(name)) 841 return f; 842 mlir::OpBuilder modBuilder(module.getBodyRegion()); 843 return modBuilder.create<mlir::FuncOp>(loc, name, type, attrs); 844 } 845 846 fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module, 847 StringRef name, mlir::Type type, 848 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 849 if (auto g = module.lookupSymbol<fir::GlobalOp>(name)) 850 return g; 851 mlir::OpBuilder modBuilder(module.getBodyRegion()); 852 return modBuilder.create<fir::GlobalOp>(loc, name, type, attrs); 853 } 854 855 namespace fir { 856 857 // Tablegen operators 858 859 #define GET_OP_CLASSES 860 #include "flang/Optimizer/Dialect/FIROps.cpp.inc" 861 862 } // namespace fir 863