1 //===-- FIROps.cpp --------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Dialect/FIROps.h" 10 #include "flang/Optimizer/Dialect/FIRAttr.h" 11 #include "flang/Optimizer/Dialect/FIROpsSupport.h" 12 #include "flang/Optimizer/Dialect/FIRType.h" 13 #include "mlir/Dialect/CommonFolders.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/Diagnostics.h" 16 #include "mlir/IR/Function.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/Module.h" 19 #include "llvm/ADT/StringSwitch.h" 20 #include "llvm/ADT/TypeSwitch.h" 21 22 using namespace fir; 23 24 /// Return true if a sequence type is of some incomplete size or a record type 25 /// is malformed or contains an incomplete sequence type. An incomplete sequence 26 /// type is one with more unknown extents in the type than have been provided 27 /// via `dynamicExtents`. Sequence types with an unknown rank are incomplete by 28 /// definition. 29 static bool verifyInType(mlir::Type inType, 30 llvm::SmallVectorImpl<llvm::StringRef> &visited, 31 unsigned dynamicExtents = 0) { 32 if (auto st = inType.dyn_cast<fir::SequenceType>()) { 33 auto shape = st.getShape(); 34 if (shape.size() == 0) 35 return true; 36 for (std::size_t i = 0, end{shape.size()}; i < end; ++i) { 37 if (shape[i] != fir::SequenceType::getUnknownExtent()) 38 continue; 39 if (dynamicExtents-- == 0) 40 return true; 41 } 42 } else if (auto rt = inType.dyn_cast<fir::RecordType>()) { 43 // don't recurse if we're already visiting this one 44 if (llvm::is_contained(visited, rt.getName())) 45 return false; 46 // keep track of record types currently being visited 47 visited.push_back(rt.getName()); 48 for (auto &field : rt.getTypeList()) 49 if (verifyInType(field.second, visited)) 50 return true; 51 visited.pop_back(); 52 } else if (auto rt = inType.dyn_cast<fir::PointerType>()) { 53 return verifyInType(rt.getEleTy(), visited); 54 } 55 return false; 56 } 57 58 static bool verifyRecordLenParams(mlir::Type inType, unsigned numLenParams) { 59 if (numLenParams > 0) { 60 if (auto rt = inType.dyn_cast<fir::RecordType>()) 61 return numLenParams != rt.getNumLenParams(); 62 return true; 63 } 64 return false; 65 } 66 67 //===----------------------------------------------------------------------===// 68 // AddfOp 69 //===----------------------------------------------------------------------===// 70 71 mlir::OpFoldResult fir::AddfOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 72 return mlir::constFoldBinaryOp<FloatAttr>( 73 opnds, [](APFloat a, APFloat b) { return a + b; }); 74 } 75 76 //===----------------------------------------------------------------------===// 77 // AllocaOp 78 //===----------------------------------------------------------------------===// 79 80 mlir::Type fir::AllocaOp::getAllocatedType() { 81 return getType().cast<ReferenceType>().getEleTy(); 82 } 83 84 /// Create a legal memory reference as return type 85 mlir::Type fir::AllocaOp::wrapResultType(mlir::Type intype) { 86 // FIR semantics: memory references to memory references are disallowed 87 if (intype.isa<ReferenceType>()) 88 return {}; 89 return ReferenceType::get(intype); 90 } 91 92 mlir::Type fir::AllocaOp::getRefTy(mlir::Type ty) { 93 return ReferenceType::get(ty); 94 } 95 96 //===----------------------------------------------------------------------===// 97 // AllocMemOp 98 //===----------------------------------------------------------------------===// 99 100 mlir::Type fir::AllocMemOp::getAllocatedType() { 101 return getType().cast<HeapType>().getEleTy(); 102 } 103 104 mlir::Type fir::AllocMemOp::getRefTy(mlir::Type ty) { 105 return HeapType::get(ty); 106 } 107 108 /// Create a legal heap reference as return type 109 mlir::Type fir::AllocMemOp::wrapResultType(mlir::Type intype) { 110 // Fortran semantics: C852 an entity cannot be both ALLOCATABLE and POINTER 111 // 8.5.3 note 1 prohibits ALLOCATABLE procedures as well 112 // FIR semantics: one may not allocate a memory reference value 113 if (intype.isa<ReferenceType>() || intype.isa<HeapType>() || 114 intype.isa<PointerType>() || intype.isa<FunctionType>()) 115 return {}; 116 return HeapType::get(intype); 117 } 118 119 //===----------------------------------------------------------------------===// 120 // BoxAddrOp 121 //===----------------------------------------------------------------------===// 122 123 mlir::OpFoldResult fir::BoxAddrOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 124 if (auto v = val().getDefiningOp()) { 125 if (auto box = dyn_cast<fir::EmboxOp>(v)) 126 return box.memref(); 127 if (auto box = dyn_cast<fir::EmboxCharOp>(v)) 128 return box.memref(); 129 } 130 return {}; 131 } 132 133 //===----------------------------------------------------------------------===// 134 // BoxCharLenOp 135 //===----------------------------------------------------------------------===// 136 137 mlir::OpFoldResult 138 fir::BoxCharLenOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 139 if (auto v = val().getDefiningOp()) { 140 if (auto box = dyn_cast<fir::EmboxCharOp>(v)) 141 return box.len(); 142 } 143 return {}; 144 } 145 146 //===----------------------------------------------------------------------===// 147 // BoxDimsOp 148 //===----------------------------------------------------------------------===// 149 150 /// Get the result types packed in a tuple tuple 151 mlir::Type fir::BoxDimsOp::getTupleType() { 152 // note: triple, but 4 is nearest power of 2 153 llvm::SmallVector<mlir::Type, 4> triple{ 154 getResult(0).getType(), getResult(1).getType(), getResult(2).getType()}; 155 return mlir::TupleType::get(triple, getContext()); 156 } 157 158 //===----------------------------------------------------------------------===// 159 // CallOp 160 //===----------------------------------------------------------------------===// 161 162 static void printCallOp(mlir::OpAsmPrinter &p, fir::CallOp &op) { 163 auto callee = op.callee(); 164 bool isDirect = callee.hasValue(); 165 p << op.getOperationName() << ' '; 166 if (isDirect) 167 p << callee.getValue(); 168 else 169 p << op.getOperand(0); 170 p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; 171 p.printOptionalAttrDict(op.getAttrs(), {fir::CallOp::calleeAttrName()}); 172 auto resultTypes{op.getResultTypes()}; 173 llvm::SmallVector<Type, 8> argTypes( 174 llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1)); 175 p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext()); 176 } 177 178 static mlir::ParseResult parseCallOp(mlir::OpAsmParser &parser, 179 mlir::OperationState &result) { 180 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> operands; 181 if (parser.parseOperandList(operands)) 182 return mlir::failure(); 183 184 mlir::NamedAttrList attrs; 185 mlir::SymbolRefAttr funcAttr; 186 bool isDirect = operands.empty(); 187 if (isDirect) 188 if (parser.parseAttribute(funcAttr, fir::CallOp::calleeAttrName(), attrs)) 189 return mlir::failure(); 190 191 Type type; 192 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) || 193 parser.parseOptionalAttrDict(attrs) || parser.parseColon() || 194 parser.parseType(type)) 195 return mlir::failure(); 196 197 auto funcType = type.dyn_cast<mlir::FunctionType>(); 198 if (!funcType) 199 return parser.emitError(parser.getNameLoc(), "expected function type"); 200 if (isDirect) { 201 if (parser.resolveOperands(operands, funcType.getInputs(), 202 parser.getNameLoc(), result.operands)) 203 return mlir::failure(); 204 } else { 205 auto funcArgs = 206 llvm::ArrayRef<mlir::OpAsmParser::OperandType>(operands).drop_front(); 207 llvm::SmallVector<mlir::Value, 8> resultArgs( 208 result.operands.begin() + (result.operands.empty() ? 0 : 1), 209 result.operands.end()); 210 if (parser.resolveOperand(operands[0], funcType, result.operands) || 211 parser.resolveOperands(funcArgs, funcType.getInputs(), 212 parser.getNameLoc(), resultArgs)) 213 return mlir::failure(); 214 } 215 result.addTypes(funcType.getResults()); 216 result.attributes = attrs; 217 return mlir::success(); 218 } 219 220 //===----------------------------------------------------------------------===// 221 // CmpfOp 222 //===----------------------------------------------------------------------===// 223 224 // Note: getCmpFPredicateNames() is inline static in StandardOps/IR/Ops.cpp 225 mlir::CmpFPredicate fir::CmpfOp::getPredicateByName(llvm::StringRef name) { 226 auto pred = mlir::symbolizeCmpFPredicate(name); 227 assert(pred.hasValue() && "invalid predicate name"); 228 return pred.getValue(); 229 } 230 231 void fir::buildCmpFOp(OpBuilder &builder, OperationState &result, 232 CmpFPredicate predicate, Value lhs, Value rhs) { 233 result.addOperands({lhs, rhs}); 234 result.types.push_back(builder.getI1Type()); 235 result.addAttribute( 236 CmpfOp::getPredicateAttrName(), 237 builder.getI64IntegerAttr(static_cast<int64_t>(predicate))); 238 } 239 240 template <typename OPTY> 241 static void printCmpOp(OpAsmPrinter &p, OPTY op) { 242 p << op.getOperationName() << ' '; 243 auto predSym = mlir::symbolizeCmpFPredicate( 244 op.template getAttrOfType<mlir::IntegerAttr>(OPTY::getPredicateAttrName()) 245 .getInt()); 246 assert(predSym.hasValue() && "invalid symbol value for predicate"); 247 p << '"' << mlir::stringifyCmpFPredicate(predSym.getValue()) << '"' << ", "; 248 p.printOperand(op.lhs()); 249 p << ", "; 250 p.printOperand(op.rhs()); 251 p.printOptionalAttrDict(op.getAttrs(), 252 /*elidedAttrs=*/{OPTY::getPredicateAttrName()}); 253 p << " : " << op.lhs().getType(); 254 } 255 256 static void printCmpfOp(OpAsmPrinter &p, CmpfOp op) { printCmpOp(p, op); } 257 258 template <typename OPTY> 259 static mlir::ParseResult parseCmpOp(mlir::OpAsmParser &parser, 260 mlir::OperationState &result) { 261 llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> ops; 262 mlir::NamedAttrList attrs; 263 mlir::Attribute predicateNameAttr; 264 mlir::Type type; 265 if (parser.parseAttribute(predicateNameAttr, OPTY::getPredicateAttrName(), 266 attrs) || 267 parser.parseComma() || parser.parseOperandList(ops, 2) || 268 parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) || 269 parser.resolveOperands(ops, type, result.operands)) 270 return failure(); 271 272 if (!predicateNameAttr.isa<mlir::StringAttr>()) 273 return parser.emitError(parser.getNameLoc(), 274 "expected string comparison predicate attribute"); 275 276 // Rewrite string attribute to an enum value. 277 llvm::StringRef predicateName = 278 predicateNameAttr.cast<mlir::StringAttr>().getValue(); 279 auto predicate = fir::CmpfOp::getPredicateByName(predicateName); 280 auto builder = parser.getBuilder(); 281 mlir::Type i1Type = builder.getI1Type(); 282 attrs.set(OPTY::getPredicateAttrName(), 283 builder.getI64IntegerAttr(static_cast<int64_t>(predicate))); 284 result.attributes = attrs; 285 result.addTypes({i1Type}); 286 return success(); 287 } 288 289 mlir::ParseResult fir::parseCmpfOp(mlir::OpAsmParser &parser, 290 mlir::OperationState &result) { 291 return parseCmpOp<fir::CmpfOp>(parser, result); 292 } 293 294 //===----------------------------------------------------------------------===// 295 // CmpcOp 296 //===----------------------------------------------------------------------===// 297 298 void fir::buildCmpCOp(OpBuilder &builder, OperationState &result, 299 CmpFPredicate predicate, Value lhs, Value rhs) { 300 result.addOperands({lhs, rhs}); 301 result.types.push_back(builder.getI1Type()); 302 result.addAttribute( 303 fir::CmpcOp::getPredicateAttrName(), 304 builder.getI64IntegerAttr(static_cast<int64_t>(predicate))); 305 } 306 307 static void printCmpcOp(OpAsmPrinter &p, fir::CmpcOp op) { printCmpOp(p, op); } 308 309 mlir::ParseResult fir::parseCmpcOp(mlir::OpAsmParser &parser, 310 mlir::OperationState &result) { 311 return parseCmpOp<fir::CmpcOp>(parser, result); 312 } 313 314 //===----------------------------------------------------------------------===// 315 // ConvertOp 316 //===----------------------------------------------------------------------===// 317 318 mlir::OpFoldResult fir::ConvertOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 319 if (value().getType() == getType()) 320 return value(); 321 if (matchPattern(value(), m_Op<fir::ConvertOp>())) { 322 auto inner = cast<fir::ConvertOp>(value().getDefiningOp()); 323 // (convert (convert 'a : logical -> i1) : i1 -> logical) ==> forward 'a 324 if (auto toTy = getType().dyn_cast<fir::LogicalType>()) 325 if (auto fromTy = inner.value().getType().dyn_cast<fir::LogicalType>()) 326 if (inner.getType().isa<mlir::IntegerType>() && (toTy == fromTy)) 327 return inner.value(); 328 // (convert (convert 'a : i1 -> logical) : logical -> i1) ==> forward 'a 329 if (auto toTy = getType().dyn_cast<mlir::IntegerType>()) 330 if (auto fromTy = inner.value().getType().dyn_cast<mlir::IntegerType>()) 331 if (inner.getType().isa<fir::LogicalType>() && (toTy == fromTy) && 332 (fromTy.getWidth() == 1)) 333 return inner.value(); 334 } 335 return {}; 336 } 337 338 bool fir::ConvertOp::isIntegerCompatible(mlir::Type ty) { 339 return ty.isa<mlir::IntegerType>() || ty.isa<mlir::IndexType>() || 340 ty.isa<fir::IntType>() || ty.isa<fir::LogicalType>() || 341 ty.isa<fir::CharacterType>(); 342 } 343 344 bool fir::ConvertOp::isFloatCompatible(mlir::Type ty) { 345 return ty.isa<mlir::FloatType>() || ty.isa<fir::RealType>(); 346 } 347 348 bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) { 349 return ty.isa<fir::ReferenceType>() || ty.isa<fir::PointerType>() || 350 ty.isa<fir::HeapType>() || ty.isa<mlir::MemRefType>() || 351 ty.isa<fir::TypeDescType>(); 352 } 353 354 //===----------------------------------------------------------------------===// 355 // CoordinateOp 356 //===----------------------------------------------------------------------===// 357 358 static mlir::ParseResult parseCoordinateOp(mlir::OpAsmParser &parser, 359 mlir::OperationState &result) { 360 llvm::ArrayRef<mlir::Type> allOperandTypes; 361 llvm::ArrayRef<mlir::Type> allResultTypes; 362 llvm::SMLoc allOperandLoc = parser.getCurrentLocation(); 363 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> allOperands; 364 if (parser.parseOperandList(allOperands)) 365 return failure(); 366 if (parser.parseOptionalAttrDict(result.attributes)) 367 return failure(); 368 if (parser.parseColon()) 369 return failure(); 370 371 mlir::FunctionType funcTy; 372 if (parser.parseType(funcTy)) 373 return failure(); 374 allOperandTypes = funcTy.getInputs(); 375 allResultTypes = funcTy.getResults(); 376 result.addTypes(allResultTypes); 377 if (parser.resolveOperands(allOperands, allOperandTypes, allOperandLoc, 378 result.operands)) 379 return failure(); 380 if (funcTy.getNumInputs()) { 381 // No inputs handled by verify 382 result.addAttribute(fir::CoordinateOp::baseType(), 383 mlir::TypeAttr::get(funcTy.getInput(0))); 384 } 385 return success(); 386 } 387 388 mlir::Type fir::CoordinateOp::getBaseType() { 389 return getAttr(CoordinateOp::baseType()).cast<mlir::TypeAttr>().getValue(); 390 } 391 392 void fir::CoordinateOp::build(OpBuilder &, OperationState &result, 393 mlir::Type resType, ValueRange operands, 394 ArrayRef<NamedAttribute> attrs) { 395 assert(operands.size() >= 1u && "mismatched number of parameters"); 396 result.addOperands(operands); 397 result.addAttribute(fir::CoordinateOp::baseType(), 398 mlir::TypeAttr::get(operands[0].getType())); 399 result.attributes.append(attrs.begin(), attrs.end()); 400 result.addTypes({resType}); 401 } 402 403 void fir::CoordinateOp::build(OpBuilder &builder, OperationState &result, 404 mlir::Type resType, mlir::Value ref, 405 ValueRange coor, ArrayRef<NamedAttribute> attrs) { 406 llvm::SmallVector<mlir::Value, 16> operands{ref}; 407 operands.append(coor.begin(), coor.end()); 408 build(builder, result, resType, operands, attrs); 409 } 410 411 //===----------------------------------------------------------------------===// 412 // DispatchOp 413 //===----------------------------------------------------------------------===// 414 415 mlir::FunctionType fir::DispatchOp::getFunctionType() { 416 auto attr = getAttr("fn_type").cast<mlir::TypeAttr>(); 417 return attr.getValue().cast<mlir::FunctionType>(); 418 } 419 420 //===----------------------------------------------------------------------===// 421 // DispatchTableOp 422 //===----------------------------------------------------------------------===// 423 424 void fir::DispatchTableOp::appendTableEntry(mlir::Operation *op) { 425 assert(mlir::isa<fir::DTEntryOp>(*op) && "operation must be a DTEntryOp"); 426 auto &block = getBlock(); 427 block.getOperations().insert(block.end(), op); 428 } 429 430 //===----------------------------------------------------------------------===// 431 // EmboxOp 432 //===----------------------------------------------------------------------===// 433 434 static mlir::ParseResult parseEmboxOp(mlir::OpAsmParser &parser, 435 mlir::OperationState &result) { 436 mlir::FunctionType type; 437 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> operands; 438 mlir::OpAsmParser::OperandType memref; 439 if (parser.parseOperand(memref)) 440 return mlir::failure(); 441 operands.push_back(memref); 442 auto &builder = parser.getBuilder(); 443 if (!parser.parseOptionalLParen()) { 444 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || 445 parser.parseRParen()) 446 return mlir::failure(); 447 auto lens = builder.getI32IntegerAttr(operands.size()); 448 result.addAttribute(fir::EmboxOp::lenpName(), lens); 449 } 450 if (!parser.parseOptionalComma()) { 451 mlir::OpAsmParser::OperandType dims; 452 if (parser.parseOperand(dims)) 453 return mlir::failure(); 454 operands.push_back(dims); 455 } else if (!parser.parseOptionalLSquare()) { 456 mlir::AffineMapAttr map; 457 if (parser.parseAttribute(map, fir::EmboxOp::layoutName(), 458 result.attributes) || 459 parser.parseRSquare()) 460 return mlir::failure(); 461 } 462 if (parser.parseOptionalAttrDict(result.attributes) || 463 parser.parseColonType(type) || 464 parser.resolveOperands(operands, type.getInputs(), parser.getNameLoc(), 465 result.operands) || 466 parser.addTypesToList(type.getResults(), result.types)) 467 return mlir::failure(); 468 return mlir::success(); 469 } 470 471 //===----------------------------------------------------------------------===// 472 // GenTypeDescOp 473 //===----------------------------------------------------------------------===// 474 475 void fir::GenTypeDescOp::build(OpBuilder &, OperationState &result, 476 mlir::TypeAttr inty) { 477 result.addAttribute("in_type", inty); 478 result.addTypes(TypeDescType::get(inty.getValue())); 479 } 480 481 //===----------------------------------------------------------------------===// 482 // GlobalOp 483 //===----------------------------------------------------------------------===// 484 485 static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { 486 // Parse the optional linkage 487 llvm::StringRef linkage; 488 auto &builder = parser.getBuilder(); 489 if (mlir::succeeded(parser.parseOptionalKeyword(&linkage))) { 490 if (fir::GlobalOp::verifyValidLinkage(linkage)) 491 return failure(); 492 mlir::StringAttr linkAttr = builder.getStringAttr(linkage); 493 result.addAttribute(fir::GlobalOp::linkageAttrName(), linkAttr); 494 } 495 496 // Parse the name as a symbol reference attribute. 497 mlir::SymbolRefAttr nameAttr; 498 if (parser.parseAttribute(nameAttr, fir::GlobalOp::symbolAttrName(), 499 result.attributes)) 500 return failure(); 501 result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), 502 builder.getStringAttr(nameAttr.getRootReference())); 503 504 bool simpleInitializer = false; 505 if (mlir::succeeded(parser.parseOptionalLParen())) { 506 Attribute attr; 507 if (parser.parseAttribute(attr, fir::GlobalOp::initValAttrName(), 508 result.attributes) || 509 parser.parseRParen()) 510 return failure(); 511 simpleInitializer = true; 512 } 513 514 if (succeeded(parser.parseOptionalKeyword("constant"))) { 515 // if "constant" keyword then mark this as a constant, not a variable 516 result.addAttribute(fir::GlobalOp::constantAttrName(), 517 builder.getUnitAttr()); 518 } 519 520 mlir::Type globalType; 521 if (parser.parseColonType(globalType)) 522 return failure(); 523 524 result.addAttribute(fir::GlobalOp::typeAttrName(), 525 mlir::TypeAttr::get(globalType)); 526 527 if (simpleInitializer) { 528 result.addRegion(); 529 } else { 530 // Parse the optional initializer body. 531 if (parser.parseRegion(*result.addRegion(), llvm::None, llvm::None)) 532 return failure(); 533 } 534 535 return success(); 536 } 537 538 void fir::GlobalOp::appendInitialValue(mlir::Operation *op) { 539 getBlock().getOperations().push_back(op); 540 } 541 542 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 543 StringRef name, bool isConstant, Type type, 544 Attribute initialVal, StringAttr linkage, 545 ArrayRef<NamedAttribute> attrs) { 546 result.addRegion(); 547 result.addAttribute(typeAttrName(), mlir::TypeAttr::get(type)); 548 result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), 549 builder.getStringAttr(name)); 550 result.addAttribute(symbolAttrName(), builder.getSymbolRefAttr(name)); 551 if (isConstant) 552 result.addAttribute(constantAttrName(), builder.getUnitAttr()); 553 if (initialVal) 554 result.addAttribute(initValAttrName(), initialVal); 555 if (linkage) 556 result.addAttribute(linkageAttrName(), linkage); 557 result.attributes.append(attrs.begin(), attrs.end()); 558 } 559 560 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 561 StringRef name, Type type, Attribute initialVal, 562 StringAttr linkage, ArrayRef<NamedAttribute> attrs) { 563 build(builder, result, name, /*isConstant=*/false, type, {}, linkage, attrs); 564 } 565 566 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 567 StringRef name, bool isConstant, Type type, 568 StringAttr linkage, ArrayRef<NamedAttribute> attrs) { 569 build(builder, result, name, isConstant, type, {}, linkage, attrs); 570 } 571 572 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 573 StringRef name, Type type, StringAttr linkage, 574 ArrayRef<NamedAttribute> attrs) { 575 build(builder, result, name, /*isConstant=*/false, type, {}, linkage, attrs); 576 } 577 578 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 579 StringRef name, bool isConstant, Type type, 580 ArrayRef<NamedAttribute> attrs) { 581 build(builder, result, name, isConstant, type, StringAttr{}, attrs); 582 } 583 584 void fir::GlobalOp::build(mlir::OpBuilder &builder, OperationState &result, 585 StringRef name, Type type, 586 ArrayRef<NamedAttribute> attrs) { 587 build(builder, result, name, /*isConstant=*/false, type, attrs); 588 } 589 590 mlir::ParseResult fir::GlobalOp::verifyValidLinkage(StringRef linkage) { 591 // Supporting only a subset of the LLVM linkage types for now 592 static const llvm::SmallVector<const char *, 3> validNames = { 593 "internal", "common", "weak"}; 594 return mlir::success(llvm::is_contained(validNames, linkage)); 595 } 596 597 //===----------------------------------------------------------------------===// 598 // IterWhileOp 599 //===----------------------------------------------------------------------===// 600 601 void fir::IterWhileOp::build(mlir::OpBuilder &builder, 602 mlir::OperationState &result, mlir::Value lb, 603 mlir::Value ub, mlir::Value step, 604 mlir::Value iterate, mlir::ValueRange iterArgs, 605 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 606 result.addOperands({lb, ub, step, iterate}); 607 result.addTypes(iterate.getType()); 608 result.addOperands(iterArgs); 609 for (auto v : iterArgs) 610 result.addTypes(v.getType()); 611 mlir::Region *bodyRegion = result.addRegion(); 612 bodyRegion->push_back(new Block{}); 613 bodyRegion->front().addArgument(builder.getIndexType()); 614 bodyRegion->front().addArgument(iterate.getType()); 615 bodyRegion->front().addArguments(iterArgs.getTypes()); 616 result.addAttributes(attributes); 617 } 618 619 static mlir::ParseResult parseIterWhileOp(mlir::OpAsmParser &parser, 620 mlir::OperationState &result) { 621 auto &builder = parser.getBuilder(); 622 mlir::OpAsmParser::OperandType inductionVariable, lb, ub, step; 623 if (parser.parseLParen() || parser.parseRegionArgument(inductionVariable) || 624 parser.parseEqual()) 625 return mlir::failure(); 626 627 // Parse loop bounds. 628 auto indexType = builder.getIndexType(); 629 auto i1Type = builder.getIntegerType(1); 630 if (parser.parseOperand(lb) || 631 parser.resolveOperand(lb, indexType, result.operands) || 632 parser.parseKeyword("to") || parser.parseOperand(ub) || 633 parser.resolveOperand(ub, indexType, result.operands) || 634 parser.parseKeyword("step") || parser.parseOperand(step) || 635 parser.parseRParen() || 636 parser.resolveOperand(step, indexType, result.operands)) 637 return mlir::failure(); 638 639 mlir::OpAsmParser::OperandType iterateVar, iterateInput; 640 if (parser.parseKeyword("and") || parser.parseLParen() || 641 parser.parseRegionArgument(iterateVar) || parser.parseEqual() || 642 parser.parseOperand(iterateInput) || parser.parseRParen() || 643 parser.resolveOperand(iterateInput, i1Type, result.operands)) 644 return mlir::failure(); 645 646 // Parse the initial iteration arguments. 647 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> regionArgs; 648 // Induction variable. 649 regionArgs.push_back(inductionVariable); 650 regionArgs.push_back(iterateVar); 651 result.addTypes(i1Type); 652 653 if (mlir::succeeded(parser.parseOptionalKeyword("iter_args"))) { 654 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> operands; 655 llvm::SmallVector<mlir::Type, 4> regionTypes; 656 // Parse assignment list and results type list. 657 if (parser.parseAssignmentList(regionArgs, operands) || 658 parser.parseArrowTypeList(regionTypes)) 659 return mlir::failure(); 660 // Resolve input operands. 661 for (auto operand_type : llvm::zip(operands, regionTypes)) 662 if (parser.resolveOperand(std::get<0>(operand_type), 663 std::get<1>(operand_type), result.operands)) 664 return mlir::failure(); 665 result.addTypes(regionTypes); 666 } 667 668 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 669 return mlir::failure(); 670 671 llvm::SmallVector<mlir::Type, 4> argTypes; 672 // Induction variable (hidden) 673 argTypes.push_back(indexType); 674 // Loop carried variables (including iterate) 675 argTypes.append(result.types.begin(), result.types.end()); 676 // Parse the body region. 677 auto *body = result.addRegion(); 678 if (regionArgs.size() != argTypes.size()) 679 return parser.emitError( 680 parser.getNameLoc(), 681 "mismatch in number of loop-carried values and defined values"); 682 683 if (parser.parseRegion(*body, regionArgs, argTypes)) 684 return failure(); 685 686 fir::IterWhileOp::ensureTerminator(*body, builder, result.location); 687 688 return mlir::success(); 689 } 690 691 static mlir::LogicalResult verify(fir::IterWhileOp op) { 692 if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step().getDefiningOp())) 693 if (cst.getValue() <= 0) 694 return op.emitOpError("constant step operand must be positive"); 695 696 // Check that the body defines as single block argument for the induction 697 // variable. 698 auto *body = op.getBody(); 699 if (!body->getArgument(1).getType().isInteger(1)) 700 return op.emitOpError( 701 "expected body second argument to be an index argument for " 702 "the induction variable"); 703 if (!body->getArgument(0).getType().isIndex()) 704 return op.emitOpError( 705 "expected body first argument to be an index argument for " 706 "the induction variable"); 707 708 auto opNumResults = op.getNumResults(); 709 if (opNumResults == 0) 710 return mlir::failure(); 711 if (op.getNumIterOperands() != opNumResults) 712 return op.emitOpError( 713 "mismatch in number of loop-carried values and defined values"); 714 if (op.getNumRegionIterArgs() != opNumResults) 715 return op.emitOpError( 716 "mismatch in number of basic block args and defined values"); 717 auto iterOperands = op.getIterOperands(); 718 auto iterArgs = op.getRegionIterArgs(); 719 auto opResults = op.getResults(); 720 unsigned i = 0; 721 for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { 722 if (std::get<0>(e).getType() != std::get<2>(e).getType()) 723 return op.emitOpError() << "types mismatch between " << i 724 << "th iter operand and defined value"; 725 if (std::get<1>(e).getType() != std::get<2>(e).getType()) 726 return op.emitOpError() << "types mismatch between " << i 727 << "th iter region arg and defined value"; 728 729 i++; 730 } 731 return mlir::success(); 732 } 733 734 static void print(mlir::OpAsmPrinter &p, fir::IterWhileOp op) { 735 p << fir::IterWhileOp::getOperationName() << " (" << op.getInductionVar() 736 << " = " << op.lowerBound() << " to " << op.upperBound() << " step " 737 << op.step() << ") and ("; 738 assert(op.hasIterOperands()); 739 auto regionArgs = op.getRegionIterArgs(); 740 auto operands = op.getIterOperands(); 741 p << regionArgs.front() << " = " << *operands.begin() << ")"; 742 if (regionArgs.size() > 1) { 743 p << " iter_args("; 744 llvm::interleaveComma( 745 llvm::zip(regionArgs.drop_front(), operands.drop_front()), p, 746 [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); 747 p << ") -> (" << op.getResultTypes().drop_front() << ')'; 748 } 749 p.printOptionalAttrDictWithKeyword(op.getAttrs(), {}); 750 p.printRegion(op.region(), /*printEntryBlockArgs=*/false, 751 /*printBlockTerminators=*/true); 752 } 753 754 mlir::Region &fir::IterWhileOp::getLoopBody() { return region(); } 755 756 bool fir::IterWhileOp::isDefinedOutsideOfLoop(mlir::Value value) { 757 return !region().isAncestor(value.getParentRegion()); 758 } 759 760 mlir::LogicalResult 761 fir::IterWhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) { 762 for (auto op : ops) 763 op->moveBefore(*this); 764 return success(); 765 } 766 767 //===----------------------------------------------------------------------===// 768 // LoadOp 769 //===----------------------------------------------------------------------===// 770 771 /// Get the element type of a reference like type; otherwise null 772 static mlir::Type elementTypeOf(mlir::Type ref) { 773 return llvm::TypeSwitch<mlir::Type, mlir::Type>(ref) 774 .Case<ReferenceType, PointerType, HeapType>( 775 [](auto type) { return type.getEleTy(); }) 776 .Default([](mlir::Type) { return mlir::Type{}; }); 777 } 778 779 mlir::ParseResult fir::LoadOp::getElementOf(mlir::Type &ele, mlir::Type ref) { 780 if ((ele = elementTypeOf(ref))) 781 return mlir::success(); 782 return mlir::failure(); 783 } 784 785 //===----------------------------------------------------------------------===// 786 // LoopOp 787 //===----------------------------------------------------------------------===// 788 789 void fir::LoopOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, 790 mlir::Value lb, mlir::Value ub, mlir::Value step, 791 bool unordered, mlir::ValueRange iterArgs, 792 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 793 result.addOperands({lb, ub, step}); 794 result.addOperands(iterArgs); 795 for (auto v : iterArgs) 796 result.addTypes(v.getType()); 797 mlir::Region *bodyRegion = result.addRegion(); 798 bodyRegion->push_back(new Block{}); 799 if (iterArgs.empty()) 800 LoopOp::ensureTerminator(*bodyRegion, builder, result.location); 801 bodyRegion->front().addArgument(builder.getIndexType()); 802 bodyRegion->front().addArguments(iterArgs.getTypes()); 803 if (unordered) 804 result.addAttribute(unorderedAttrName(), builder.getUnitAttr()); 805 result.addAttributes(attributes); 806 } 807 808 static mlir::ParseResult parseLoopOp(mlir::OpAsmParser &parser, 809 mlir::OperationState &result) { 810 auto &builder = parser.getBuilder(); 811 mlir::OpAsmParser::OperandType inductionVariable, lb, ub, step; 812 // Parse the induction variable followed by '='. 813 if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual()) 814 return mlir::failure(); 815 816 // Parse loop bounds. 817 auto indexType = builder.getIndexType(); 818 if (parser.parseOperand(lb) || 819 parser.resolveOperand(lb, indexType, result.operands) || 820 parser.parseKeyword("to") || parser.parseOperand(ub) || 821 parser.resolveOperand(ub, indexType, result.operands) || 822 parser.parseKeyword("step") || parser.parseOperand(step) || 823 parser.resolveOperand(step, indexType, result.operands)) 824 return failure(); 825 826 if (mlir::succeeded(parser.parseOptionalKeyword("unordered"))) 827 result.addAttribute(fir::LoopOp::unorderedAttrName(), 828 builder.getUnitAttr()); 829 830 // Parse the optional initial iteration arguments. 831 llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> regionArgs, operands; 832 llvm::SmallVector<mlir::Type, 4> argTypes; 833 regionArgs.push_back(inductionVariable); 834 835 if (succeeded(parser.parseOptionalKeyword("iter_args"))) { 836 // Parse assignment list and results type list. 837 if (parser.parseAssignmentList(regionArgs, operands) || 838 parser.parseArrowTypeList(result.types)) 839 return failure(); 840 // Resolve input operands. 841 for (auto operand_type : llvm::zip(operands, result.types)) 842 if (parser.resolveOperand(std::get<0>(operand_type), 843 std::get<1>(operand_type), result.operands)) 844 return failure(); 845 } 846 847 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 848 return mlir::failure(); 849 850 // Induction variable. 851 argTypes.push_back(indexType); 852 // Loop carried variables 853 argTypes.append(result.types.begin(), result.types.end()); 854 // Parse the body region. 855 auto *body = result.addRegion(); 856 if (regionArgs.size() != argTypes.size()) 857 return parser.emitError( 858 parser.getNameLoc(), 859 "mismatch in number of loop-carried values and defined values"); 860 861 if (parser.parseRegion(*body, regionArgs, argTypes)) 862 return failure(); 863 864 fir::LoopOp::ensureTerminator(*body, builder, result.location); 865 866 return mlir::success(); 867 } 868 869 fir::LoopOp fir::getForInductionVarOwner(mlir::Value val) { 870 auto ivArg = val.dyn_cast<mlir::BlockArgument>(); 871 if (!ivArg) 872 return {}; 873 assert(ivArg.getOwner() && "unlinked block argument"); 874 auto *containingInst = ivArg.getOwner()->getParentOp(); 875 return dyn_cast_or_null<fir::LoopOp>(containingInst); 876 } 877 878 // Lifted from loop.loop 879 static mlir::LogicalResult verify(fir::LoopOp op) { 880 if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step().getDefiningOp())) 881 if (cst.getValue() <= 0) 882 return op.emitOpError("constant step operand must be positive"); 883 884 // Check that the body defines as single block argument for the induction 885 // variable. 886 auto *body = op.getBody(); 887 if (!body->getArgument(0).getType().isIndex()) 888 return op.emitOpError( 889 "expected body first argument to be an index argument for " 890 "the induction variable"); 891 892 auto opNumResults = op.getNumResults(); 893 if (opNumResults == 0) 894 return success(); 895 if (op.getNumIterOperands() != opNumResults) 896 return op.emitOpError( 897 "mismatch in number of loop-carried values and defined values"); 898 if (op.getNumRegionIterArgs() != opNumResults) 899 return op.emitOpError( 900 "mismatch in number of basic block args and defined values"); 901 auto iterOperands = op.getIterOperands(); 902 auto iterArgs = op.getRegionIterArgs(); 903 auto opResults = op.getResults(); 904 unsigned i = 0; 905 for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { 906 if (std::get<0>(e).getType() != std::get<2>(e).getType()) 907 return op.emitOpError() << "types mismatch between " << i 908 << "th iter operand and defined value"; 909 if (std::get<1>(e).getType() != std::get<2>(e).getType()) 910 return op.emitOpError() << "types mismatch between " << i 911 << "th iter region arg and defined value"; 912 913 i++; 914 } 915 return success(); 916 } 917 918 static void print(mlir::OpAsmPrinter &p, fir::LoopOp op) { 919 bool printBlockTerminators = false; 920 p << fir::LoopOp::getOperationName() << ' ' << op.getInductionVar() << " = " 921 << op.lowerBound() << " to " << op.upperBound() << " step " << op.step(); 922 if (op.unordered()) 923 p << " unordered"; 924 if (op.hasIterOperands()) { 925 p << " iter_args("; 926 auto regionArgs = op.getRegionIterArgs(); 927 auto operands = op.getIterOperands(); 928 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { 929 p << std::get<0>(it) << " = " << std::get<1>(it); 930 }); 931 p << ") -> (" << op.getResultTypes() << ')'; 932 printBlockTerminators = true; 933 } 934 p.printOptionalAttrDictWithKeyword(op.getAttrs(), 935 {fir::LoopOp::unorderedAttrName()}); 936 p.printRegion(op.region(), /*printEntryBlockArgs=*/false, 937 printBlockTerminators); 938 } 939 940 mlir::Region &fir::LoopOp::getLoopBody() { return region(); } 941 942 bool fir::LoopOp::isDefinedOutsideOfLoop(mlir::Value value) { 943 return !region().isAncestor(value.getParentRegion()); 944 } 945 946 mlir::LogicalResult 947 fir::LoopOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) { 948 for (auto op : ops) 949 op->moveBefore(*this); 950 return success(); 951 } 952 953 //===----------------------------------------------------------------------===// 954 // MulfOp 955 //===----------------------------------------------------------------------===// 956 957 mlir::OpFoldResult fir::MulfOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 958 return mlir::constFoldBinaryOp<FloatAttr>( 959 opnds, [](APFloat a, APFloat b) { return a * b; }); 960 } 961 962 //===----------------------------------------------------------------------===// 963 // ResultOp 964 //===----------------------------------------------------------------------===// 965 966 static mlir::LogicalResult verify(fir::ResultOp op) { 967 auto parentOp = op.getParentOp(); 968 auto results = parentOp->getResults(); 969 auto operands = op.getOperands(); 970 971 if (parentOp->getNumResults() != op.getNumOperands()) 972 return op.emitOpError() << "parent of result must have same arity"; 973 for (auto e : llvm::zip(results, operands)) 974 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 975 return op.emitOpError() 976 << "types mismatch between result op and its parent"; 977 return success(); 978 } 979 980 //===----------------------------------------------------------------------===// 981 // SelectOp 982 //===----------------------------------------------------------------------===// 983 984 static constexpr llvm::StringRef getCompareOffsetAttr() { 985 return "compare_operand_offsets"; 986 } 987 988 static constexpr llvm::StringRef getTargetOffsetAttr() { 989 return "target_operand_offsets"; 990 } 991 992 template <typename A, typename... AdditionalArgs> 993 static A getSubOperands(unsigned pos, A allArgs, 994 mlir::DenseIntElementsAttr ranges, 995 AdditionalArgs &&... additionalArgs) { 996 unsigned start = 0; 997 for (unsigned i = 0; i < pos; ++i) 998 start += (*(ranges.begin() + i)).getZExtValue(); 999 return allArgs.slice(start, (*(ranges.begin() + pos)).getZExtValue(), 1000 std::forward<AdditionalArgs>(additionalArgs)...); 1001 } 1002 1003 static mlir::MutableOperandRange 1004 getMutableSuccessorOperands(unsigned pos, mlir::MutableOperandRange operands, 1005 StringRef offsetAttr) { 1006 Operation *owner = operands.getOwner(); 1007 NamedAttribute targetOffsetAttr = 1008 *owner->getMutableAttrDict().getNamed(offsetAttr); 1009 return getSubOperands( 1010 pos, operands, targetOffsetAttr.second.cast<DenseIntElementsAttr>(), 1011 mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr)); 1012 } 1013 1014 static unsigned denseElementsSize(mlir::DenseIntElementsAttr attr) { 1015 return attr.getNumElements(); 1016 } 1017 1018 llvm::Optional<mlir::OperandRange> fir::SelectOp::getCompareOperands(unsigned) { 1019 return {}; 1020 } 1021 1022 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1023 fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 1024 return {}; 1025 } 1026 1027 llvm::Optional<mlir::MutableOperandRange> 1028 fir::SelectOp::getMutableSuccessorOperands(unsigned oper) { 1029 return ::getMutableSuccessorOperands(oper, targetArgsMutable(), 1030 getTargetOffsetAttr()); 1031 } 1032 1033 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1034 fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1035 unsigned oper) { 1036 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1037 auto segments = 1038 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1039 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1040 } 1041 1042 unsigned fir::SelectOp::targetOffsetSize() { 1043 return denseElementsSize( 1044 getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr())); 1045 } 1046 1047 //===----------------------------------------------------------------------===// 1048 // SelectCaseOp 1049 //===----------------------------------------------------------------------===// 1050 1051 llvm::Optional<mlir::OperandRange> 1052 fir::SelectCaseOp::getCompareOperands(unsigned cond) { 1053 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getCompareOffsetAttr()); 1054 return {getSubOperands(cond, compareArgs(), a)}; 1055 } 1056 1057 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1058 fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef<mlir::Value> operands, 1059 unsigned cond) { 1060 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getCompareOffsetAttr()); 1061 auto segments = 1062 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1063 return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; 1064 } 1065 1066 llvm::Optional<mlir::MutableOperandRange> 1067 fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) { 1068 return ::getMutableSuccessorOperands(oper, targetArgsMutable(), 1069 getTargetOffsetAttr()); 1070 } 1071 1072 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1073 fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1074 unsigned oper) { 1075 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1076 auto segments = 1077 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1078 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1079 } 1080 1081 // parser for fir.select_case Op 1082 static mlir::ParseResult parseSelectCase(mlir::OpAsmParser &parser, 1083 mlir::OperationState &result) { 1084 mlir::OpAsmParser::OperandType selector; 1085 mlir::Type type; 1086 if (parseSelector(parser, result, selector, type)) 1087 return mlir::failure(); 1088 1089 llvm::SmallVector<mlir::Attribute, 8> attrs; 1090 llvm::SmallVector<mlir::OpAsmParser::OperandType, 8> opers; 1091 llvm::SmallVector<mlir::Block *, 8> dests; 1092 llvm::SmallVector<llvm::SmallVector<mlir::Value, 8>, 8> destArgs; 1093 llvm::SmallVector<int32_t, 8> argOffs; 1094 int32_t offSize = 0; 1095 while (true) { 1096 mlir::Attribute attr; 1097 mlir::Block *dest; 1098 llvm::SmallVector<mlir::Value, 8> destArg; 1099 mlir::NamedAttrList temp; 1100 if (parser.parseAttribute(attr, "a", temp) || isValidCaseAttr(attr) || 1101 parser.parseComma()) 1102 return mlir::failure(); 1103 attrs.push_back(attr); 1104 if (attr.dyn_cast_or_null<mlir::UnitAttr>()) { 1105 argOffs.push_back(0); 1106 } else if (attr.dyn_cast_or_null<fir::ClosedIntervalAttr>()) { 1107 mlir::OpAsmParser::OperandType oper1; 1108 mlir::OpAsmParser::OperandType oper2; 1109 if (parser.parseOperand(oper1) || parser.parseComma() || 1110 parser.parseOperand(oper2) || parser.parseComma()) 1111 return mlir::failure(); 1112 opers.push_back(oper1); 1113 opers.push_back(oper2); 1114 argOffs.push_back(2); 1115 offSize += 2; 1116 } else { 1117 mlir::OpAsmParser::OperandType oper; 1118 if (parser.parseOperand(oper) || parser.parseComma()) 1119 return mlir::failure(); 1120 opers.push_back(oper); 1121 argOffs.push_back(1); 1122 ++offSize; 1123 } 1124 if (parser.parseSuccessorAndUseList(dest, destArg)) 1125 return mlir::failure(); 1126 dests.push_back(dest); 1127 destArgs.push_back(destArg); 1128 if (!parser.parseOptionalRSquare()) 1129 break; 1130 if (parser.parseComma()) 1131 return mlir::failure(); 1132 } 1133 result.addAttribute(fir::SelectCaseOp::getCasesAttr(), 1134 parser.getBuilder().getArrayAttr(attrs)); 1135 if (parser.resolveOperands(opers, type, result.operands)) 1136 return mlir::failure(); 1137 llvm::SmallVector<int32_t, 8> targOffs; 1138 int32_t toffSize = 0; 1139 const auto count = dests.size(); 1140 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 1141 result.addSuccessors(dests[i]); 1142 result.addOperands(destArgs[i]); 1143 auto argSize = destArgs[i].size(); 1144 targOffs.push_back(argSize); 1145 toffSize += argSize; 1146 } 1147 auto &bld = parser.getBuilder(); 1148 result.addAttribute(fir::SelectCaseOp::getOperandSegmentSizeAttr(), 1149 bld.getI32VectorAttr({1, offSize, toffSize})); 1150 result.addAttribute(getCompareOffsetAttr(), bld.getI32VectorAttr(argOffs)); 1151 result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(targOffs)); 1152 return mlir::success(); 1153 } 1154 1155 unsigned fir::SelectCaseOp::compareOffsetSize() { 1156 return denseElementsSize( 1157 getAttrOfType<mlir::DenseIntElementsAttr>(getCompareOffsetAttr())); 1158 } 1159 1160 unsigned fir::SelectCaseOp::targetOffsetSize() { 1161 return denseElementsSize( 1162 getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr())); 1163 } 1164 1165 void fir::SelectCaseOp::build(mlir::OpBuilder &builder, 1166 mlir::OperationState &result, 1167 mlir::Value selector, 1168 llvm::ArrayRef<mlir::Attribute> compareAttrs, 1169 llvm::ArrayRef<mlir::ValueRange> cmpOperands, 1170 llvm::ArrayRef<mlir::Block *> destinations, 1171 llvm::ArrayRef<mlir::ValueRange> destOperands, 1172 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 1173 result.addOperands(selector); 1174 result.addAttribute(getCasesAttr(), builder.getArrayAttr(compareAttrs)); 1175 llvm::SmallVector<int32_t, 8> operOffs; 1176 int32_t operSize = 0; 1177 for (auto attr : compareAttrs) { 1178 if (attr.isa<fir::ClosedIntervalAttr>()) { 1179 operOffs.push_back(2); 1180 operSize += 2; 1181 } else if (attr.isa<mlir::UnitAttr>()) { 1182 operOffs.push_back(0); 1183 } else { 1184 operOffs.push_back(1); 1185 ++operSize; 1186 } 1187 } 1188 for (auto ops : cmpOperands) 1189 result.addOperands(ops); 1190 result.addAttribute(getCompareOffsetAttr(), 1191 builder.getI32VectorAttr(operOffs)); 1192 const auto count = destinations.size(); 1193 for (auto d : destinations) 1194 result.addSuccessors(d); 1195 const auto opCount = destOperands.size(); 1196 llvm::SmallVector<int32_t, 8> argOffs; 1197 int32_t sumArgs = 0; 1198 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 1199 if (i < opCount) { 1200 result.addOperands(destOperands[i]); 1201 const auto argSz = destOperands[i].size(); 1202 argOffs.push_back(argSz); 1203 sumArgs += argSz; 1204 } else { 1205 argOffs.push_back(0); 1206 } 1207 } 1208 result.addAttribute(getOperandSegmentSizeAttr(), 1209 builder.getI32VectorAttr({1, operSize, sumArgs})); 1210 result.addAttribute(getTargetOffsetAttr(), builder.getI32VectorAttr(argOffs)); 1211 result.addAttributes(attributes); 1212 } 1213 1214 /// This builder has a slightly simplified interface in that the list of 1215 /// operands need not be partitioned by the builder. Instead the operands are 1216 /// partitioned here, before being passed to the default builder. This 1217 /// partitioning is unchecked, so can go awry on bad input. 1218 void fir::SelectCaseOp::build(mlir::OpBuilder &builder, 1219 mlir::OperationState &result, 1220 mlir::Value selector, 1221 llvm::ArrayRef<mlir::Attribute> compareAttrs, 1222 llvm::ArrayRef<mlir::Value> cmpOpList, 1223 llvm::ArrayRef<mlir::Block *> destinations, 1224 llvm::ArrayRef<mlir::ValueRange> destOperands, 1225 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 1226 llvm::SmallVector<mlir::ValueRange, 16> cmpOpers; 1227 auto iter = cmpOpList.begin(); 1228 for (auto &attr : compareAttrs) { 1229 if (attr.isa<fir::ClosedIntervalAttr>()) { 1230 cmpOpers.push_back(mlir::ValueRange({iter, iter + 2})); 1231 iter += 2; 1232 } else if (attr.isa<UnitAttr>()) { 1233 cmpOpers.push_back(mlir::ValueRange{}); 1234 } else { 1235 cmpOpers.push_back(mlir::ValueRange({iter, iter + 1})); 1236 ++iter; 1237 } 1238 } 1239 build(builder, result, selector, compareAttrs, cmpOpers, destinations, 1240 destOperands, attributes); 1241 } 1242 1243 //===----------------------------------------------------------------------===// 1244 // SelectRankOp 1245 //===----------------------------------------------------------------------===// 1246 1247 llvm::Optional<mlir::OperandRange> 1248 fir::SelectRankOp::getCompareOperands(unsigned) { 1249 return {}; 1250 } 1251 1252 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1253 fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 1254 return {}; 1255 } 1256 1257 llvm::Optional<mlir::MutableOperandRange> 1258 fir::SelectRankOp::getMutableSuccessorOperands(unsigned oper) { 1259 return ::getMutableSuccessorOperands(oper, targetArgsMutable(), 1260 getTargetOffsetAttr()); 1261 } 1262 1263 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1264 fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1265 unsigned oper) { 1266 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1267 auto segments = 1268 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1269 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1270 } 1271 1272 unsigned fir::SelectRankOp::targetOffsetSize() { 1273 return denseElementsSize( 1274 getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr())); 1275 } 1276 1277 //===----------------------------------------------------------------------===// 1278 // SelectTypeOp 1279 //===----------------------------------------------------------------------===// 1280 1281 llvm::Optional<mlir::OperandRange> 1282 fir::SelectTypeOp::getCompareOperands(unsigned) { 1283 return {}; 1284 } 1285 1286 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1287 fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 1288 return {}; 1289 } 1290 1291 llvm::Optional<mlir::MutableOperandRange> 1292 fir::SelectTypeOp::getMutableSuccessorOperands(unsigned oper) { 1293 return ::getMutableSuccessorOperands(oper, targetArgsMutable(), 1294 getTargetOffsetAttr()); 1295 } 1296 1297 llvm::Optional<llvm::ArrayRef<mlir::Value>> 1298 fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 1299 unsigned oper) { 1300 auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()); 1301 auto segments = 1302 getAttrOfType<mlir::DenseIntElementsAttr>(getOperandSegmentSizeAttr()); 1303 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 1304 } 1305 1306 static ParseResult parseSelectType(OpAsmParser &parser, 1307 OperationState &result) { 1308 mlir::OpAsmParser::OperandType selector; 1309 mlir::Type type; 1310 if (parseSelector(parser, result, selector, type)) 1311 return mlir::failure(); 1312 1313 llvm::SmallVector<mlir::Attribute, 8> attrs; 1314 llvm::SmallVector<mlir::Block *, 8> dests; 1315 llvm::SmallVector<llvm::SmallVector<mlir::Value, 8>, 8> destArgs; 1316 while (true) { 1317 mlir::Attribute attr; 1318 mlir::Block *dest; 1319 llvm::SmallVector<mlir::Value, 8> destArg; 1320 mlir::NamedAttrList temp; 1321 if (parser.parseAttribute(attr, "a", temp) || parser.parseComma() || 1322 parser.parseSuccessorAndUseList(dest, destArg)) 1323 return mlir::failure(); 1324 attrs.push_back(attr); 1325 dests.push_back(dest); 1326 destArgs.push_back(destArg); 1327 if (!parser.parseOptionalRSquare()) 1328 break; 1329 if (parser.parseComma()) 1330 return mlir::failure(); 1331 } 1332 auto &bld = parser.getBuilder(); 1333 result.addAttribute(fir::SelectTypeOp::getCasesAttr(), 1334 bld.getArrayAttr(attrs)); 1335 llvm::SmallVector<int32_t, 8> argOffs; 1336 int32_t offSize = 0; 1337 const auto count = dests.size(); 1338 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 1339 result.addSuccessors(dests[i]); 1340 result.addOperands(destArgs[i]); 1341 auto argSize = destArgs[i].size(); 1342 argOffs.push_back(argSize); 1343 offSize += argSize; 1344 } 1345 result.addAttribute(fir::SelectTypeOp::getOperandSegmentSizeAttr(), 1346 bld.getI32VectorAttr({1, 0, offSize})); 1347 result.addAttribute(getTargetOffsetAttr(), bld.getI32VectorAttr(argOffs)); 1348 return mlir::success(); 1349 } 1350 1351 unsigned fir::SelectTypeOp::targetOffsetSize() { 1352 return denseElementsSize( 1353 getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr())); 1354 } 1355 1356 //===----------------------------------------------------------------------===// 1357 // StoreOp 1358 //===----------------------------------------------------------------------===// 1359 1360 mlir::Type fir::StoreOp::elementType(mlir::Type refType) { 1361 if (auto ref = refType.dyn_cast<ReferenceType>()) 1362 return ref.getEleTy(); 1363 if (auto ref = refType.dyn_cast<PointerType>()) 1364 return ref.getEleTy(); 1365 if (auto ref = refType.dyn_cast<HeapType>()) 1366 return ref.getEleTy(); 1367 return {}; 1368 } 1369 1370 //===----------------------------------------------------------------------===// 1371 // StringLitOp 1372 //===----------------------------------------------------------------------===// 1373 1374 bool fir::StringLitOp::isWideValue() { 1375 auto eleTy = getType().cast<fir::SequenceType>().getEleTy(); 1376 return eleTy.cast<fir::CharacterType>().getFKind() != 1; 1377 } 1378 1379 //===----------------------------------------------------------------------===// 1380 // SubfOp 1381 //===----------------------------------------------------------------------===// 1382 1383 mlir::OpFoldResult fir::SubfOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) { 1384 return mlir::constFoldBinaryOp<FloatAttr>( 1385 opnds, [](APFloat a, APFloat b) { return a - b; }); 1386 } 1387 1388 //===----------------------------------------------------------------------===// 1389 // WhereOp 1390 //===----------------------------------------------------------------------===// 1391 void fir::WhereOp::build(mlir::OpBuilder &builder, OperationState &result, 1392 mlir::Value cond, bool withElseRegion) { 1393 build(builder, result, llvm::None, cond, withElseRegion); 1394 } 1395 1396 void fir::WhereOp::build(mlir::OpBuilder &builder, OperationState &result, 1397 mlir::TypeRange resultTypes, mlir::Value cond, 1398 bool withElseRegion) { 1399 result.addOperands(cond); 1400 result.addTypes(resultTypes); 1401 1402 mlir::Region *thenRegion = result.addRegion(); 1403 thenRegion->push_back(new mlir::Block()); 1404 if (resultTypes.empty()) 1405 WhereOp::ensureTerminator(*thenRegion, builder, result.location); 1406 1407 mlir::Region *elseRegion = result.addRegion(); 1408 if (withElseRegion) { 1409 elseRegion->push_back(new mlir::Block()); 1410 if (resultTypes.empty()) 1411 WhereOp::ensureTerminator(*elseRegion, builder, result.location); 1412 } 1413 } 1414 1415 static mlir::ParseResult parseWhereOp(OpAsmParser &parser, 1416 OperationState &result) { 1417 result.regions.reserve(2); 1418 mlir::Region *thenRegion = result.addRegion(); 1419 mlir::Region *elseRegion = result.addRegion(); 1420 1421 auto &builder = parser.getBuilder(); 1422 OpAsmParser::OperandType cond; 1423 mlir::Type i1Type = builder.getIntegerType(1); 1424 if (parser.parseOperand(cond) || 1425 parser.resolveOperand(cond, i1Type, result.operands)) 1426 return mlir::failure(); 1427 1428 if (parser.parseRegion(*thenRegion, {}, {})) 1429 return mlir::failure(); 1430 1431 WhereOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location); 1432 1433 if (!parser.parseOptionalKeyword("else")) { 1434 if (parser.parseRegion(*elseRegion, {}, {})) 1435 return mlir::failure(); 1436 WhereOp::ensureTerminator(*elseRegion, parser.getBuilder(), 1437 result.location); 1438 } 1439 1440 // Parse the optional attribute list. 1441 if (parser.parseOptionalAttrDict(result.attributes)) 1442 return mlir::failure(); 1443 1444 return mlir::success(); 1445 } 1446 1447 static LogicalResult verify(fir::WhereOp op) { 1448 if (op.getNumResults() != 0 && op.otherRegion().empty()) 1449 return op.emitOpError("must have an else block if defining values"); 1450 1451 return mlir::success(); 1452 } 1453 1454 static void print(mlir::OpAsmPrinter &p, fir::WhereOp op) { 1455 bool printBlockTerminators = false; 1456 p << fir::WhereOp::getOperationName() << ' ' << op.condition(); 1457 if (!op.results().empty()) { 1458 p << " -> (" << op.getResultTypes() << ')'; 1459 printBlockTerminators = true; 1460 } 1461 p.printRegion(op.whereRegion(), /*printEntryBlockArgs=*/false, 1462 printBlockTerminators); 1463 1464 // Print the 'else' regions if it exists and has a block. 1465 auto &otherReg = op.otherRegion(); 1466 if (!otherReg.empty()) { 1467 p << " else"; 1468 p.printRegion(otherReg, /*printEntryBlockArgs=*/false, 1469 printBlockTerminators); 1470 } 1471 p.printOptionalAttrDict(op.getAttrs()); 1472 } 1473 1474 //===----------------------------------------------------------------------===// 1475 1476 mlir::ParseResult fir::isValidCaseAttr(mlir::Attribute attr) { 1477 if (attr.dyn_cast_or_null<mlir::UnitAttr>() || 1478 attr.dyn_cast_or_null<ClosedIntervalAttr>() || 1479 attr.dyn_cast_or_null<PointIntervalAttr>() || 1480 attr.dyn_cast_or_null<LowerBoundAttr>() || 1481 attr.dyn_cast_or_null<UpperBoundAttr>()) 1482 return mlir::success(); 1483 return mlir::failure(); 1484 } 1485 1486 unsigned fir::getCaseArgumentOffset(llvm::ArrayRef<mlir::Attribute> cases, 1487 unsigned dest) { 1488 unsigned o = 0; 1489 for (unsigned i = 0; i < dest; ++i) { 1490 auto &attr = cases[i]; 1491 if (!attr.dyn_cast_or_null<mlir::UnitAttr>()) { 1492 ++o; 1493 if (attr.dyn_cast_or_null<ClosedIntervalAttr>()) 1494 ++o; 1495 } 1496 } 1497 return o; 1498 } 1499 1500 mlir::ParseResult fir::parseSelector(mlir::OpAsmParser &parser, 1501 mlir::OperationState &result, 1502 mlir::OpAsmParser::OperandType &selector, 1503 mlir::Type &type) { 1504 if (parser.parseOperand(selector) || parser.parseColonType(type) || 1505 parser.resolveOperand(selector, type, result.operands) || 1506 parser.parseLSquare()) 1507 return mlir::failure(); 1508 return mlir::success(); 1509 } 1510 1511 /// Generic pretty-printer of a binary operation 1512 static void printBinaryOp(Operation *op, OpAsmPrinter &p) { 1513 assert(op->getNumOperands() == 2 && "binary op must have two operands"); 1514 assert(op->getNumResults() == 1 && "binary op must have one result"); 1515 1516 p << op->getName() << ' ' << op->getOperand(0) << ", " << op->getOperand(1); 1517 p.printOptionalAttrDict(op->getAttrs()); 1518 p << " : " << op->getResult(0).getType(); 1519 } 1520 1521 /// Generic pretty-printer of an unary operation 1522 static void printUnaryOp(Operation *op, OpAsmPrinter &p) { 1523 assert(op->getNumOperands() == 1 && "unary op must have one operand"); 1524 assert(op->getNumResults() == 1 && "unary op must have one result"); 1525 1526 p << op->getName() << ' ' << op->getOperand(0); 1527 p.printOptionalAttrDict(op->getAttrs()); 1528 p << " : " << op->getResult(0).getType(); 1529 } 1530 1531 bool fir::isReferenceLike(mlir::Type type) { 1532 return type.isa<fir::ReferenceType>() || type.isa<fir::HeapType>() || 1533 type.isa<fir::PointerType>(); 1534 } 1535 1536 mlir::FuncOp fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module, 1537 StringRef name, mlir::FunctionType type, 1538 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 1539 if (auto f = module.lookupSymbol<mlir::FuncOp>(name)) 1540 return f; 1541 mlir::OpBuilder modBuilder(module.getBodyRegion()); 1542 modBuilder.setInsertionPoint(module.getBody()->getTerminator()); 1543 return modBuilder.create<mlir::FuncOp>(loc, name, type, attrs); 1544 } 1545 1546 fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module, 1547 StringRef name, mlir::Type type, 1548 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 1549 if (auto g = module.lookupSymbol<fir::GlobalOp>(name)) 1550 return g; 1551 mlir::OpBuilder modBuilder(module.getBodyRegion()); 1552 return modBuilder.create<fir::GlobalOp>(loc, name, type, attrs); 1553 } 1554 1555 // Tablegen operators 1556 1557 #define GET_OP_CLASSES 1558 #include "flang/Optimizer/Dialect/FIROps.cpp.inc" 1559 1560